1#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10use bytes::BytesMut;
11use mssql_codec::connection::Connection;
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode, TlsStream};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::{MAX_PACKET_SIZE, PacketType};
15use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
16use tds_protocol::rpc::{RpcParam, RpcRequest, TypeInfo as RpcTypeInfo};
17use tds_protocol::token::{
18 ColMetaData, ColumnData, EnvChange, EnvChangeType, NbcRow, RawRow, Token, TokenParser,
19};
20use tds_protocol::tvp::{
21 TvpColumnDef as TvpWireColumnDef, TvpColumnFlags, TvpEncoder, TvpWireType, encode_tvp_bit,
22 encode_tvp_decimal, encode_tvp_float, encode_tvp_int, encode_tvp_null, encode_tvp_nvarchar,
23 encode_tvp_varbinary,
24};
25use tokio::net::TcpStream;
26use tokio::time::timeout;
27
28use crate::config::Config;
29use crate::error::{Error, Result};
30#[cfg(feature = "otel")]
31use crate::instrumentation::InstrumentationContext;
32use crate::state::{ConnectionState, Disconnected, InTransaction, Ready};
33use crate::statement_cache::StatementCache;
34use crate::stream::{MultiResultStream, QueryStream};
35use crate::transaction::SavePoint;
36
37pub struct Client<S: ConnectionState> {
43 config: Config,
44 _state: PhantomData<S>,
45 connection: Option<ConnectionHandle>,
47 server_version: Option<u32>,
49 current_database: Option<String>,
51 statement_cache: StatementCache,
53 transaction_descriptor: u64,
57 #[cfg(feature = "otel")]
59 instrumentation: InstrumentationContext,
60}
61
62#[allow(dead_code)] enum ConnectionHandle {
70 Tls(Connection<TlsStream<TcpStream>>),
72 TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
74 Plain(Connection<TcpStream>),
76}
77
78impl Client<Disconnected> {
79 pub async fn connect(config: Config) -> Result<Client<Ready>> {
90 let max_redirects = config.redirect.max_redirects;
91 let follow_redirects = config.redirect.follow_redirects;
92 let mut attempts = 0;
93 let mut current_config = config;
94
95 loop {
96 attempts += 1;
97 if attempts > max_redirects + 1 {
98 return Err(Error::TooManyRedirects { max: max_redirects });
99 }
100
101 match Self::try_connect(¤t_config).await {
102 Ok(client) => return Ok(client),
103 Err(Error::Routing { host, port }) => {
104 if !follow_redirects {
105 return Err(Error::Routing { host, port });
106 }
107 tracing::info!(
108 host = %host,
109 port = port,
110 attempt = attempts,
111 max_redirects = max_redirects,
112 "following Azure SQL routing redirect"
113 );
114 current_config = current_config.with_host(&host).with_port(port);
115 continue;
116 }
117 Err(e) => return Err(e),
118 }
119 }
120 }
121
122 async fn try_connect(config: &Config) -> Result<Client<Ready>> {
123 tracing::info!(
124 host = %config.host,
125 port = config.port,
126 database = ?config.database,
127 "connecting to SQL Server"
128 );
129
130 let addr = format!("{}:{}", config.host, config.port);
131
132 tracing::debug!("establishing TCP connection to {}", addr);
134 let tcp_stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
135 .await
136 .map_err(|_| Error::ConnectTimeout)?
137 .map_err(|e| Error::Io(Arc::new(e)))?;
138
139 tcp_stream
141 .set_nodelay(true)
142 .map_err(|e| Error::Io(Arc::new(e)))?;
143
144 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
146
147 if tls_mode.is_tls_first() {
149 return Self::connect_tds_8(config, tcp_stream).await;
150 }
151
152 Self::connect_tds_7x(config, tcp_stream).await
154 }
155
156 async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
160 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
161
162 let tls_config = TlsConfig::new()
164 .strict_mode(true)
165 .trust_server_certificate(config.trust_server_certificate);
166
167 let tls_connector = TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
168
169 let tls_stream = timeout(
171 config.timeouts.tls_timeout,
172 tls_connector.connect(tcp_stream, &config.host),
173 )
174 .await
175 .map_err(|_| Error::TlsTimeout)?
176 .map_err(|e| Error::Tls(e.to_string()))?;
177
178 tracing::debug!("TLS handshake completed (strict mode)");
179
180 let mut connection = Connection::new(tls_stream);
182
183 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
185 Self::send_prelogin(&mut connection, &prelogin).await?;
186 let _prelogin_response = Self::receive_prelogin(&mut connection).await?;
187
188 let login = Self::build_login7(config);
190 Self::send_login7(&mut connection, &login).await?;
191
192 let (server_version, current_database, routing) =
194 Self::process_login_response(&mut connection).await?;
195
196 if let Some((host, port)) = routing {
198 return Err(Error::Routing { host, port });
199 }
200
201 Ok(Client {
202 config: config.clone(),
203 _state: PhantomData,
204 connection: Some(ConnectionHandle::Tls(connection)),
205 server_version,
206 current_database: current_database.clone(),
207 statement_cache: StatementCache::with_default_size(),
208 transaction_descriptor: 0, #[cfg(feature = "otel")]
210 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
211 .with_database(current_database.unwrap_or_default()),
212 })
213 }
214
215 async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
223 use bytes::BufMut;
224 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
225 use tokio::io::{AsyncReadExt, AsyncWriteExt};
226
227 tracing::debug!("using TDS 7.x flow (PreLogin first)");
228
229 let client_encryption = if config.encrypt {
232 EncryptionLevel::On
233 } else {
234 EncryptionLevel::Off
235 };
236 let prelogin = Self::build_prelogin(config, client_encryption);
237 tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
238 let prelogin_bytes = prelogin.encode();
239
240 let header = PacketHeader::new(
242 PacketType::PreLogin,
243 PacketStatus::END_OF_MESSAGE,
244 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
245 );
246
247 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
248 header.encode(&mut packet_buf);
249 packet_buf.put_slice(&prelogin_bytes);
250
251 tcp_stream
252 .write_all(&packet_buf)
253 .await
254 .map_err(|e| Error::Io(Arc::new(e)))?;
255
256 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
258 tcp_stream
259 .read_exact(&mut header_buf)
260 .await
261 .map_err(|e| Error::Io(Arc::new(e)))?;
262
263 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
264 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
265
266 let mut response_buf = vec![0u8; payload_length];
267 tcp_stream
268 .read_exact(&mut response_buf)
269 .await
270 .map_err(|e| Error::Io(Arc::new(e)))?;
271
272 let prelogin_response =
273 PreLogin::decode(&response_buf[..]).map_err(|e| Error::Protocol(e.to_string()))?;
274
275 let server_encryption = prelogin_response.encryption;
277 tracing::debug!(encryption = ?server_encryption, "server encryption level");
278
279 let negotiated_encryption = match (client_encryption, server_encryption) {
285 (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
286 EncryptionLevel::NotSupported
287 }
288 (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
289 (EncryptionLevel::On, EncryptionLevel::Off)
290 | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
291 return Err(Error::Protocol(
292 "Server does not support requested encryption level".to_string(),
293 ));
294 }
295 _ => EncryptionLevel::On,
296 };
297
298 let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
301
302 if use_tls {
303 let tls_config =
306 TlsConfig::new().trust_server_certificate(config.trust_server_certificate);
307
308 let tls_connector =
309 TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
310
311 let mut tls_stream = timeout(
313 config.timeouts.tls_timeout,
314 tls_connector.connect_with_prelogin(tcp_stream, &config.host),
315 )
316 .await
317 .map_err(|_| Error::TlsTimeout)?
318 .map_err(|e| Error::Tls(e.to_string()))?;
319
320 tracing::debug!("TLS handshake completed (PreLogin wrapped)");
321
322 let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
324
325 if login_only_encryption {
326 use tokio::io::AsyncWriteExt;
334
335 let login = Self::build_login7(config);
337 let login_payload = login.encode();
338
339 let max_packet = MAX_PACKET_SIZE;
341 let max_payload = max_packet - PACKET_HEADER_SIZE;
342 let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
343 let total_chunks = chunks.len();
344
345 for (i, chunk) in chunks.into_iter().enumerate() {
346 let is_last = i == total_chunks - 1;
347 let status = if is_last {
348 PacketStatus::END_OF_MESSAGE
349 } else {
350 PacketStatus::NORMAL
351 };
352
353 let header = PacketHeader::new(
354 PacketType::Tds7Login,
355 status,
356 (PACKET_HEADER_SIZE + chunk.len()) as u16,
357 );
358
359 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
360 header.encode(&mut packet_buf);
361 packet_buf.put_slice(chunk);
362
363 tls_stream
364 .write_all(&packet_buf)
365 .await
366 .map_err(|e| Error::Io(Arc::new(e)))?;
367 }
368
369 tls_stream
371 .flush()
372 .await
373 .map_err(|e| Error::Io(Arc::new(e)))?;
374
375 tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
376
377 let (wrapper, _client_conn) = tls_stream.into_inner();
381 let tcp_stream = wrapper.into_inner();
382
383 let mut connection = Connection::new(tcp_stream);
385
386 let (server_version, current_database, routing) =
388 Self::process_login_response(&mut connection).await?;
389
390 if let Some((host, port)) = routing {
392 return Err(Error::Routing { host, port });
393 }
394
395 Ok(Client {
397 config: config.clone(),
398 _state: PhantomData,
399 connection: Some(ConnectionHandle::Plain(connection)),
400 server_version,
401 current_database: current_database.clone(),
402 statement_cache: StatementCache::with_default_size(),
403 transaction_descriptor: 0, #[cfg(feature = "otel")]
405 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
406 .with_database(current_database.unwrap_or_default()),
407 })
408 } else {
409 let mut connection = Connection::new(tls_stream);
412
413 let login = Self::build_login7(config);
415 Self::send_login7(&mut connection, &login).await?;
416
417 let (server_version, current_database, routing) =
419 Self::process_login_response(&mut connection).await?;
420
421 if let Some((host, port)) = routing {
423 return Err(Error::Routing { host, port });
424 }
425
426 Ok(Client {
427 config: config.clone(),
428 _state: PhantomData,
429 connection: Some(ConnectionHandle::TlsPrelogin(connection)),
430 server_version,
431 current_database: current_database.clone(),
432 statement_cache: StatementCache::with_default_size(),
433 transaction_descriptor: 0, #[cfg(feature = "otel")]
435 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
436 .with_database(current_database.unwrap_or_default()),
437 })
438 }
439 } else {
440 tracing::warn!(
442 "Connecting without TLS encryption. This is insecure and should only be \
443 used for development/testing on trusted networks."
444 );
445
446 let login = Self::build_login7(config);
448 let login_bytes = login.encode();
449 tracing::debug!("Login7 packet built: {} bytes", login_bytes.len(),);
450 tracing::debug!(
452 "Login7 fixed header (94 bytes): {:02X?}",
453 &login_bytes[..login_bytes.len().min(94)]
454 );
455 if login_bytes.len() > 94 {
457 tracing::debug!(
458 "Login7 variable data ({} bytes): {:02X?}",
459 login_bytes.len() - 94,
460 &login_bytes[94..]
461 );
462 }
463
464 let login_header = PacketHeader::new(
466 PacketType::Tds7Login,
467 PacketStatus::END_OF_MESSAGE,
468 (PACKET_HEADER_SIZE + login_bytes.len()) as u16,
469 )
470 .with_packet_id(1);
471 let mut login_packet_buf =
472 BytesMut::with_capacity(PACKET_HEADER_SIZE + login_bytes.len());
473 login_header.encode(&mut login_packet_buf);
474 login_packet_buf.put_slice(&login_bytes);
475
476 tracing::debug!(
477 "Sending Login7 packet: {} bytes total, header: {:02X?}",
478 login_packet_buf.len(),
479 &login_packet_buf[..PACKET_HEADER_SIZE]
480 );
481 tcp_stream
482 .write_all(&login_packet_buf)
483 .await
484 .map_err(|e| Error::Io(Arc::new(e)))?;
485 tcp_stream
486 .flush()
487 .await
488 .map_err(|e| Error::Io(Arc::new(e)))?;
489 tracing::debug!("Login7 sent and flushed over raw TCP");
490
491 let mut response_header_buf = [0u8; PACKET_HEADER_SIZE];
493 tcp_stream
494 .read_exact(&mut response_header_buf)
495 .await
496 .map_err(|e| Error::Io(Arc::new(e)))?;
497
498 let response_type = response_header_buf[0];
499 let response_length =
500 u16::from_be_bytes([response_header_buf[2], response_header_buf[3]]) as usize;
501 tracing::debug!(
502 "Response header: type={:#04X}, length={}",
503 response_type,
504 response_length
505 );
506
507 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
509 let mut response_payload = vec![0u8; payload_length];
510 tcp_stream
511 .read_exact(&mut response_payload)
512 .await
513 .map_err(|e| Error::Io(Arc::new(e)))?;
514 tracing::debug!(
515 "Response payload: {} bytes, first 32: {:02X?}",
516 response_payload.len(),
517 &response_payload[..response_payload.len().min(32)]
518 );
519
520 let connection = Connection::new(tcp_stream);
522
523 let response_bytes = bytes::Bytes::from(response_payload);
525 let mut parser = TokenParser::new(response_bytes);
526 let mut server_version = None;
527 let mut current_database = None;
528 let routing = None;
529
530 while let Some(token) = parser
531 .next_token()
532 .map_err(|e| Error::Protocol(e.to_string()))?
533 {
534 match token {
535 Token::LoginAck(ack) => {
536 tracing::info!(
537 version = ack.tds_version,
538 interface = ack.interface,
539 prog_name = %ack.prog_name,
540 "login acknowledged"
541 );
542 server_version = Some(ack.tds_version);
543 }
544 Token::EnvChange(env) => {
545 Self::process_env_change(&env, &mut current_database, &mut None);
546 }
547 Token::Error(err) => {
548 return Err(Error::Server {
549 number: err.number,
550 state: err.state,
551 class: err.class,
552 message: err.message.clone(),
553 server: if err.server.is_empty() {
554 None
555 } else {
556 Some(err.server.clone())
557 },
558 procedure: if err.procedure.is_empty() {
559 None
560 } else {
561 Some(err.procedure.clone())
562 },
563 line: err.line as u32,
564 });
565 }
566 Token::Info(info) => {
567 tracing::info!(
568 number = info.number,
569 message = %info.message,
570 "server info message"
571 );
572 }
573 Token::Done(done) => {
574 if done.status.error {
575 return Err(Error::Protocol("login failed".to_string()));
576 }
577 break;
578 }
579 _ => {}
580 }
581 }
582
583 if let Some((host, port)) = routing {
585 return Err(Error::Routing { host, port });
586 }
587
588 Ok(Client {
589 config: config.clone(),
590 _state: PhantomData,
591 connection: Some(ConnectionHandle::Plain(connection)),
592 server_version,
593 current_database: current_database.clone(),
594 statement_cache: StatementCache::with_default_size(),
595 transaction_descriptor: 0, #[cfg(feature = "otel")]
597 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
598 .with_database(current_database.unwrap_or_default()),
599 })
600 }
601 }
602
603 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
605 let mut prelogin = PreLogin::new().with_encryption(encryption);
606
607 if config.mars {
608 prelogin = prelogin.with_mars(true);
609 }
610
611 if let Some(ref instance) = config.instance {
612 prelogin = prelogin.with_instance(instance);
613 }
614
615 prelogin
616 }
617
618 fn build_login7(config: &Config) -> Login7 {
620 let mut login = Login7::new()
621 .with_packet_size(config.packet_size as u32)
622 .with_app_name(&config.application_name)
623 .with_server_name(&config.host)
624 .with_hostname(&config.host);
625
626 if let Some(ref database) = config.database {
627 login = login.with_database(database);
628 }
629
630 match &config.credentials {
632 mssql_auth::Credentials::SqlServer { username, password } => {
633 login = login.with_sql_auth(username.as_ref(), password.as_ref());
634 }
635 _ => {}
637 }
638
639 login
640 }
641
642 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
644 where
645 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
646 {
647 let payload = prelogin.encode();
648 let max_packet = MAX_PACKET_SIZE;
649
650 connection
651 .send_message(PacketType::PreLogin, payload, max_packet)
652 .await
653 .map_err(|e| Error::Protocol(e.to_string()))
654 }
655
656 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
658 where
659 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
660 {
661 let message = connection
662 .read_message()
663 .await
664 .map_err(|e| Error::Protocol(e.to_string()))?
665 .ok_or(Error::ConnectionClosed)?;
666
667 PreLogin::decode(&message.payload[..]).map_err(|e| Error::Protocol(e.to_string()))
668 }
669
670 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
672 where
673 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
674 {
675 let payload = login.encode();
676 let max_packet = MAX_PACKET_SIZE;
677
678 connection
679 .send_message(PacketType::Tds7Login, payload, max_packet)
680 .await
681 .map_err(|e| Error::Protocol(e.to_string()))
682 }
683
684 async fn process_login_response<T>(
688 connection: &mut Connection<T>,
689 ) -> Result<(Option<u32>, Option<String>, Option<(String, u16)>)>
690 where
691 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
692 {
693 let message = connection
694 .read_message()
695 .await
696 .map_err(|e| Error::Protocol(e.to_string()))?
697 .ok_or(Error::ConnectionClosed)?;
698
699 let response_bytes = message.payload;
700
701 let mut parser = TokenParser::new(response_bytes);
702 let mut server_version = None;
703 let mut database = None;
704 let mut routing = None;
705
706 while let Some(token) = parser
707 .next_token()
708 .map_err(|e| Error::Protocol(e.to_string()))?
709 {
710 match token {
711 Token::LoginAck(ack) => {
712 tracing::info!(
713 version = ack.tds_version,
714 interface = ack.interface,
715 prog_name = %ack.prog_name,
716 "login acknowledged"
717 );
718 server_version = Some(ack.tds_version);
719 }
720 Token::EnvChange(env) => {
721 Self::process_env_change(&env, &mut database, &mut routing);
722 }
723 Token::Error(err) => {
724 return Err(Error::Server {
725 number: err.number,
726 state: err.state,
727 class: err.class,
728 message: err.message.clone(),
729 server: if err.server.is_empty() {
730 None
731 } else {
732 Some(err.server.clone())
733 },
734 procedure: if err.procedure.is_empty() {
735 None
736 } else {
737 Some(err.procedure.clone())
738 },
739 line: err.line as u32,
740 });
741 }
742 Token::Info(info) => {
743 tracing::info!(
744 number = info.number,
745 message = %info.message,
746 "server info message"
747 );
748 }
749 Token::Done(done) => {
750 if done.status.error {
751 return Err(Error::Protocol("login failed".to_string()));
752 }
753 break;
754 }
755 _ => {}
756 }
757 }
758
759 Ok((server_version, database, routing))
760 }
761
762 fn process_env_change(
764 env: &EnvChange,
765 database: &mut Option<String>,
766 routing: &mut Option<(String, u16)>,
767 ) {
768 use tds_protocol::token::EnvChangeValue;
769
770 match env.env_type {
771 EnvChangeType::Database => {
772 if let EnvChangeValue::String(ref new_value) = env.new_value {
773 tracing::debug!(database = %new_value, "database changed");
774 *database = Some(new_value.clone());
775 }
776 }
777 EnvChangeType::Routing => {
778 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
779 tracing::info!(host = %host, port = port, "routing redirect received");
780 *routing = Some((host.clone(), port));
781 }
782 }
783 _ => {
784 if let EnvChangeValue::String(ref new_value) = env.new_value {
785 tracing::debug!(
786 env_type = ?env.env_type,
787 new_value = %new_value,
788 "environment change"
789 );
790 }
791 }
792 }
793 }
794}
795
796impl<S: ConnectionState> Client<S> {
798 async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
804 let payload =
805 tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
806 let max_packet = self.config.packet_size as usize;
807
808 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
809
810 match connection {
811 ConnectionHandle::Tls(conn) => {
812 conn.send_message(PacketType::SqlBatch, payload, max_packet)
813 .await
814 .map_err(|e| Error::Protocol(e.to_string()))?;
815 }
816 ConnectionHandle::TlsPrelogin(conn) => {
817 conn.send_message(PacketType::SqlBatch, payload, max_packet)
818 .await
819 .map_err(|e| Error::Protocol(e.to_string()))?;
820 }
821 ConnectionHandle::Plain(conn) => {
822 conn.send_message(PacketType::SqlBatch, payload, max_packet)
823 .await
824 .map_err(|e| Error::Protocol(e.to_string()))?;
825 }
826 }
827
828 Ok(())
829 }
830
831 async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
835 let payload = rpc.encode_with_transaction(self.transaction_descriptor);
836 let max_packet = self.config.packet_size as usize;
837
838 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
839
840 match connection {
841 ConnectionHandle::Tls(conn) => {
842 conn.send_message(PacketType::Rpc, payload, max_packet)
843 .await
844 .map_err(|e| Error::Protocol(e.to_string()))?;
845 }
846 ConnectionHandle::TlsPrelogin(conn) => {
847 conn.send_message(PacketType::Rpc, payload, max_packet)
848 .await
849 .map_err(|e| Error::Protocol(e.to_string()))?;
850 }
851 ConnectionHandle::Plain(conn) => {
852 conn.send_message(PacketType::Rpc, payload, max_packet)
853 .await
854 .map_err(|e| Error::Protocol(e.to_string()))?;
855 }
856 }
857
858 Ok(())
859 }
860
861 fn convert_params(params: &[&(dyn crate::ToSql + Sync)]) -> Result<Vec<RpcParam>> {
863 use bytes::{BufMut, BytesMut};
864 use mssql_types::SqlValue;
865
866 params
867 .iter()
868 .enumerate()
869 .map(|(i, p)| {
870 let sql_value = p.to_sql()?;
871 let name = format!("@p{}", i + 1);
872
873 Ok(match sql_value {
874 SqlValue::Null => RpcParam::null(&name, RpcTypeInfo::nvarchar(1)),
875 SqlValue::Bool(v) => {
876 let mut buf = BytesMut::with_capacity(1);
877 buf.put_u8(if v { 1 } else { 0 });
878 RpcParam::new(&name, RpcTypeInfo::bit(), buf.freeze())
879 }
880 SqlValue::TinyInt(v) => {
881 let mut buf = BytesMut::with_capacity(1);
882 buf.put_u8(v);
883 RpcParam::new(&name, RpcTypeInfo::tinyint(), buf.freeze())
884 }
885 SqlValue::SmallInt(v) => {
886 let mut buf = BytesMut::with_capacity(2);
887 buf.put_i16_le(v);
888 RpcParam::new(&name, RpcTypeInfo::smallint(), buf.freeze())
889 }
890 SqlValue::Int(v) => RpcParam::int(&name, v),
891 SqlValue::BigInt(v) => RpcParam::bigint(&name, v),
892 SqlValue::Float(v) => {
893 let mut buf = BytesMut::with_capacity(4);
894 buf.put_f32_le(v);
895 RpcParam::new(&name, RpcTypeInfo::real(), buf.freeze())
896 }
897 SqlValue::Double(v) => {
898 let mut buf = BytesMut::with_capacity(8);
899 buf.put_f64_le(v);
900 RpcParam::new(&name, RpcTypeInfo::float(), buf.freeze())
901 }
902 SqlValue::String(ref s) => RpcParam::nvarchar(&name, s),
903 SqlValue::Binary(ref b) => {
904 RpcParam::new(&name, RpcTypeInfo::varbinary(b.len() as u16), b.clone())
905 }
906 SqlValue::Xml(ref s) => RpcParam::nvarchar(&name, s),
907 #[cfg(feature = "uuid")]
908 SqlValue::Uuid(u) => {
909 let bytes = u.as_bytes();
911 let mut buf = BytesMut::with_capacity(16);
912 buf.put_u32_le(u32::from_be_bytes([
914 bytes[0], bytes[1], bytes[2], bytes[3],
915 ]));
916 buf.put_u16_le(u16::from_be_bytes([bytes[4], bytes[5]]));
917 buf.put_u16_le(u16::from_be_bytes([bytes[6], bytes[7]]));
918 buf.put_slice(&bytes[8..16]);
919 RpcParam::new(&name, RpcTypeInfo::uniqueidentifier(), buf.freeze())
920 }
921 #[cfg(feature = "decimal")]
922 SqlValue::Decimal(d) => {
923 RpcParam::nvarchar(&name, &d.to_string())
925 }
926 #[cfg(feature = "chrono")]
927 SqlValue::Date(_)
928 | SqlValue::Time(_)
929 | SqlValue::DateTime(_)
930 | SqlValue::DateTimeOffset(_) => {
931 let s = match &sql_value {
934 SqlValue::Date(d) => d.to_string(),
935 SqlValue::Time(t) => t.to_string(),
936 SqlValue::DateTime(dt) => dt.to_string(),
937 SqlValue::DateTimeOffset(dto) => dto.to_rfc3339(),
938 _ => unreachable!(),
939 };
940 RpcParam::nvarchar(&name, &s)
941 }
942 #[cfg(feature = "json")]
943 SqlValue::Json(ref j) => RpcParam::nvarchar(&name, &j.to_string()),
944 SqlValue::Tvp(ref tvp_data) => {
945 Self::encode_tvp_param(&name, tvp_data)?
947 }
948 _ => {
950 return Err(Error::Type(mssql_types::TypeError::UnsupportedConversion {
951 from: sql_value.type_name().to_string(),
952 to: "RPC parameter",
953 }));
954 }
955 })
956 })
957 .collect()
958 }
959
960 fn encode_tvp_param(name: &str, tvp_data: &mssql_types::TvpData) -> Result<RpcParam> {
965 let wire_columns: Vec<TvpWireColumnDef> = tvp_data
967 .columns
968 .iter()
969 .map(|col| {
970 let wire_type = Self::convert_tvp_column_type(&col.column_type);
971 TvpWireColumnDef {
972 wire_type,
973 flags: TvpColumnFlags {
974 nullable: col.nullable,
975 },
976 }
977 })
978 .collect();
979
980 let encoder = TvpEncoder::new(&tvp_data.schema, &tvp_data.type_name, &wire_columns);
982
983 let mut buf = BytesMut::with_capacity(256);
985
986 encoder.encode_metadata(&mut buf);
988
989 for row in &tvp_data.rows {
991 encoder.encode_row(&mut buf, |row_buf| {
992 for (col_idx, value) in row.iter().enumerate() {
993 let wire_type = &wire_columns[col_idx].wire_type;
994 Self::encode_tvp_value(value, wire_type, row_buf);
995 }
996 });
997 }
998
999 encoder.encode_end(&mut buf);
1001
1002 let type_info = RpcTypeInfo {
1006 type_id: 0xF3, max_length: None,
1008 precision: None,
1009 scale: None,
1010 collation: None,
1011 };
1012
1013 Ok(RpcParam {
1014 name: name.to_string(),
1015 flags: tds_protocol::rpc::ParamFlags::default(),
1016 type_info,
1017 value: Some(buf.freeze()),
1018 })
1019 }
1020
1021 fn convert_tvp_column_type(col_type: &mssql_types::TvpColumnType) -> TvpWireType {
1023 match col_type {
1024 mssql_types::TvpColumnType::Bit => TvpWireType::Bit,
1025 mssql_types::TvpColumnType::TinyInt => TvpWireType::Int { size: 1 },
1026 mssql_types::TvpColumnType::SmallInt => TvpWireType::Int { size: 2 },
1027 mssql_types::TvpColumnType::Int => TvpWireType::Int { size: 4 },
1028 mssql_types::TvpColumnType::BigInt => TvpWireType::Int { size: 8 },
1029 mssql_types::TvpColumnType::Real => TvpWireType::Float { size: 4 },
1030 mssql_types::TvpColumnType::Float => TvpWireType::Float { size: 8 },
1031 mssql_types::TvpColumnType::Decimal { precision, scale } => TvpWireType::Decimal {
1032 precision: *precision,
1033 scale: *scale,
1034 },
1035 mssql_types::TvpColumnType::NVarChar { max_length } => TvpWireType::NVarChar {
1036 max_length: *max_length,
1037 },
1038 mssql_types::TvpColumnType::VarChar { max_length } => TvpWireType::VarChar {
1039 max_length: *max_length,
1040 },
1041 mssql_types::TvpColumnType::VarBinary { max_length } => TvpWireType::VarBinary {
1042 max_length: *max_length,
1043 },
1044 mssql_types::TvpColumnType::UniqueIdentifier => TvpWireType::Guid,
1045 mssql_types::TvpColumnType::Date => TvpWireType::Date,
1046 mssql_types::TvpColumnType::Time { scale } => TvpWireType::Time { scale: *scale },
1047 mssql_types::TvpColumnType::DateTime2 { scale } => {
1048 TvpWireType::DateTime2 { scale: *scale }
1049 }
1050 mssql_types::TvpColumnType::DateTimeOffset { scale } => {
1051 TvpWireType::DateTimeOffset { scale: *scale }
1052 }
1053 mssql_types::TvpColumnType::Xml => TvpWireType::Xml,
1054 }
1055 }
1056
1057 fn encode_tvp_value(
1059 value: &mssql_types::SqlValue,
1060 wire_type: &TvpWireType,
1061 buf: &mut BytesMut,
1062 ) {
1063 use mssql_types::SqlValue;
1064
1065 match value {
1066 SqlValue::Null => {
1067 encode_tvp_null(wire_type, buf);
1068 }
1069 SqlValue::Bool(v) => {
1070 encode_tvp_bit(*v, buf);
1071 }
1072 SqlValue::TinyInt(v) => {
1073 encode_tvp_int(*v as i64, 1, buf);
1074 }
1075 SqlValue::SmallInt(v) => {
1076 encode_tvp_int(*v as i64, 2, buf);
1077 }
1078 SqlValue::Int(v) => {
1079 encode_tvp_int(*v as i64, 4, buf);
1080 }
1081 SqlValue::BigInt(v) => {
1082 encode_tvp_int(*v, 8, buf);
1083 }
1084 SqlValue::Float(v) => {
1085 encode_tvp_float(*v as f64, 4, buf);
1086 }
1087 SqlValue::Double(v) => {
1088 encode_tvp_float(*v, 8, buf);
1089 }
1090 SqlValue::String(s) => {
1091 let max_len = match wire_type {
1092 TvpWireType::NVarChar { max_length } => *max_length,
1093 _ => 4000,
1094 };
1095 encode_tvp_nvarchar(s, max_len, buf);
1096 }
1097 SqlValue::Binary(b) => {
1098 let max_len = match wire_type {
1099 TvpWireType::VarBinary { max_length } => *max_length,
1100 _ => 8000,
1101 };
1102 encode_tvp_varbinary(b, max_len, buf);
1103 }
1104 #[cfg(feature = "decimal")]
1105 SqlValue::Decimal(d) => {
1106 let sign = if d.is_sign_negative() { 0u8 } else { 1u8 };
1107 let mantissa = d.mantissa().unsigned_abs();
1108 encode_tvp_decimal(sign, mantissa, buf);
1109 }
1110 #[cfg(feature = "uuid")]
1111 SqlValue::Uuid(u) => {
1112 let bytes = u.as_bytes();
1113 tds_protocol::tvp::encode_tvp_guid(bytes, buf);
1114 }
1115 #[cfg(feature = "chrono")]
1116 SqlValue::Date(d) => {
1117 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1119 let days = d.signed_duration_since(base).num_days() as u32;
1120 tds_protocol::tvp::encode_tvp_date(days, buf);
1121 }
1122 #[cfg(feature = "chrono")]
1123 SqlValue::Time(t) => {
1124 use chrono::Timelike;
1125 let nanos =
1126 t.num_seconds_from_midnight() as u64 * 1_000_000_000 + t.nanosecond() as u64;
1127 let intervals = nanos / 100;
1128 let scale = match wire_type {
1129 TvpWireType::Time { scale } => *scale,
1130 _ => 7,
1131 };
1132 tds_protocol::tvp::encode_tvp_time(intervals, scale, buf);
1133 }
1134 #[cfg(feature = "chrono")]
1135 SqlValue::DateTime(dt) => {
1136 use chrono::Timelike;
1137 let nanos = dt.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1139 + dt.time().nanosecond() as u64;
1140 let intervals = nanos / 100;
1141 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1143 let days = dt.date().signed_duration_since(base).num_days() as u32;
1144 let scale = match wire_type {
1145 TvpWireType::DateTime2 { scale } => *scale,
1146 _ => 7,
1147 };
1148 tds_protocol::tvp::encode_tvp_datetime2(intervals, days, scale, buf);
1149 }
1150 #[cfg(feature = "chrono")]
1151 SqlValue::DateTimeOffset(dto) => {
1152 use chrono::{Offset, Timelike};
1153 let nanos = dto.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1155 + dto.time().nanosecond() as u64;
1156 let intervals = nanos / 100;
1157 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1159 let days = dto.date_naive().signed_duration_since(base).num_days() as u32;
1160 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1162 let scale = match wire_type {
1163 TvpWireType::DateTimeOffset { scale } => *scale,
1164 _ => 7,
1165 };
1166 tds_protocol::tvp::encode_tvp_datetimeoffset(
1167 intervals,
1168 days,
1169 offset_minutes,
1170 scale,
1171 buf,
1172 );
1173 }
1174 #[cfg(feature = "json")]
1175 SqlValue::Json(j) => {
1176 encode_tvp_nvarchar(&j.to_string(), 0xFFFF, buf);
1178 }
1179 SqlValue::Xml(s) => {
1180 encode_tvp_nvarchar(s, 0xFFFF, buf);
1182 }
1183 SqlValue::Tvp(_) => {
1184 encode_tvp_null(wire_type, buf);
1186 }
1187 _ => {
1189 encode_tvp_null(wire_type, buf);
1190 }
1191 }
1192 }
1193
1194 async fn read_query_response(
1196 &mut self,
1197 ) -> Result<(Vec<crate::row::Column>, Vec<crate::row::Row>)> {
1198 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
1199
1200 let message = match connection {
1201 ConnectionHandle::Tls(conn) => conn
1202 .read_message()
1203 .await
1204 .map_err(|e| Error::Protocol(e.to_string()))?,
1205 ConnectionHandle::TlsPrelogin(conn) => conn
1206 .read_message()
1207 .await
1208 .map_err(|e| Error::Protocol(e.to_string()))?,
1209 ConnectionHandle::Plain(conn) => conn
1210 .read_message()
1211 .await
1212 .map_err(|e| Error::Protocol(e.to_string()))?,
1213 }
1214 .ok_or(Error::ConnectionClosed)?;
1215
1216 let mut parser = TokenParser::new(message.payload);
1217 let mut columns: Vec<crate::row::Column> = Vec::new();
1218 let mut rows: Vec<crate::row::Row> = Vec::new();
1219 let mut protocol_metadata: Option<ColMetaData> = None;
1220
1221 loop {
1222 let token = parser
1224 .next_token_with_metadata(protocol_metadata.as_ref())
1225 .map_err(|e| Error::Protocol(e.to_string()))?;
1226
1227 let Some(token) = token else {
1228 break;
1229 };
1230
1231 match token {
1232 Token::ColMetaData(meta) => {
1233 columns = meta
1234 .columns
1235 .iter()
1236 .enumerate()
1237 .map(|(i, col)| {
1238 let type_name = format!("{:?}", col.type_id);
1239 let mut column = crate::row::Column::new(&col.name, i, type_name)
1240 .with_nullable(col.flags & 0x01 != 0);
1241
1242 if let Some(max_len) = col.type_info.max_length {
1243 column = column.with_max_length(max_len);
1244 }
1245 if let (Some(prec), Some(scale)) =
1246 (col.type_info.precision, col.type_info.scale)
1247 {
1248 column = column.with_precision_scale(prec, scale);
1249 }
1250 column
1251 })
1252 .collect();
1253
1254 tracing::debug!(columns = columns.len(), "received column metadata");
1255 protocol_metadata = Some(meta);
1256 }
1257 Token::Row(raw_row) => {
1258 if let Some(ref meta) = protocol_metadata {
1259 let row = Self::convert_raw_row(&raw_row, meta, &columns)?;
1260 rows.push(row);
1261 }
1262 }
1263 Token::NbcRow(nbc_row) => {
1264 if let Some(ref meta) = protocol_metadata {
1265 let row = Self::convert_nbc_row(&nbc_row, meta, &columns)?;
1266 rows.push(row);
1267 }
1268 }
1269 Token::Error(err) => {
1270 return Err(Error::Server {
1271 number: err.number,
1272 state: err.state,
1273 class: err.class,
1274 message: err.message.clone(),
1275 server: if err.server.is_empty() {
1276 None
1277 } else {
1278 Some(err.server.clone())
1279 },
1280 procedure: if err.procedure.is_empty() {
1281 None
1282 } else {
1283 Some(err.procedure.clone())
1284 },
1285 line: err.line as u32,
1286 });
1287 }
1288 Token::Done(done) => {
1289 if done.status.error {
1290 return Err(Error::Query("query failed".to_string()));
1291 }
1292 tracing::debug!(
1293 row_count = done.row_count,
1294 has_more = done.status.more,
1295 "query complete"
1296 );
1297 break;
1298 }
1299 Token::DoneProc(done) => {
1300 if done.status.error {
1301 return Err(Error::Query("query failed".to_string()));
1302 }
1303 }
1304 Token::DoneInProc(done) => {
1305 if done.status.error {
1306 return Err(Error::Query("query failed".to_string()));
1307 }
1308 }
1309 Token::Info(info) => {
1310 tracing::debug!(
1311 number = info.number,
1312 message = %info.message,
1313 "server info message"
1314 );
1315 }
1316 _ => {}
1317 }
1318 }
1319
1320 tracing::debug!(
1321 columns = columns.len(),
1322 rows = rows.len(),
1323 "query response parsed"
1324 );
1325 Ok((columns, rows))
1326 }
1327
1328 fn convert_raw_row(
1332 raw: &RawRow,
1333 meta: &ColMetaData,
1334 columns: &[crate::row::Column],
1335 ) -> Result<crate::row::Row> {
1336 let mut values = Vec::with_capacity(meta.columns.len());
1337 let mut buf = raw.data.as_ref();
1338
1339 for col in &meta.columns {
1340 let value = Self::parse_column_value(&mut buf, col)?;
1341 values.push(value);
1342 }
1343
1344 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1345 }
1346
1347 fn convert_nbc_row(
1351 nbc: &NbcRow,
1352 meta: &ColMetaData,
1353 columns: &[crate::row::Column],
1354 ) -> Result<crate::row::Row> {
1355 let mut values = Vec::with_capacity(meta.columns.len());
1356 let mut buf = nbc.data.as_ref();
1357
1358 for (i, col) in meta.columns.iter().enumerate() {
1359 if nbc.is_null(i) {
1360 values.push(mssql_types::SqlValue::Null);
1361 } else {
1362 let value = Self::parse_column_value(&mut buf, col)?;
1363 values.push(value);
1364 }
1365 }
1366
1367 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1368 }
1369
1370 fn parse_column_value(buf: &mut &[u8], col: &ColumnData) -> Result<mssql_types::SqlValue> {
1372 use bytes::Buf;
1373 use mssql_types::SqlValue;
1374 use tds_protocol::types::TypeId;
1375
1376 let value = match col.type_id {
1377 TypeId::Null => SqlValue::Null,
1379
1380 TypeId::Int1 => {
1382 if buf.remaining() < 1 {
1383 return Err(Error::Protocol("unexpected EOF reading TINYINT".into()));
1384 }
1385 SqlValue::TinyInt(buf.get_u8())
1386 }
1387 TypeId::Bit => {
1388 if buf.remaining() < 1 {
1389 return Err(Error::Protocol("unexpected EOF reading BIT".into()));
1390 }
1391 SqlValue::Bool(buf.get_u8() != 0)
1392 }
1393
1394 TypeId::Int2 => {
1396 if buf.remaining() < 2 {
1397 return Err(Error::Protocol("unexpected EOF reading SMALLINT".into()));
1398 }
1399 SqlValue::SmallInt(buf.get_i16_le())
1400 }
1401
1402 TypeId::Int4 => {
1404 if buf.remaining() < 4 {
1405 return Err(Error::Protocol("unexpected EOF reading INT".into()));
1406 }
1407 SqlValue::Int(buf.get_i32_le())
1408 }
1409 TypeId::Float4 => {
1410 if buf.remaining() < 4 {
1411 return Err(Error::Protocol("unexpected EOF reading REAL".into()));
1412 }
1413 SqlValue::Float(buf.get_f32_le())
1414 }
1415
1416 TypeId::Int8 => {
1418 if buf.remaining() < 8 {
1419 return Err(Error::Protocol("unexpected EOF reading BIGINT".into()));
1420 }
1421 SqlValue::BigInt(buf.get_i64_le())
1422 }
1423 TypeId::Float8 => {
1424 if buf.remaining() < 8 {
1425 return Err(Error::Protocol("unexpected EOF reading FLOAT".into()));
1426 }
1427 SqlValue::Double(buf.get_f64_le())
1428 }
1429 TypeId::Money => {
1430 if buf.remaining() < 8 {
1431 return Err(Error::Protocol("unexpected EOF reading MONEY".into()));
1432 }
1433 let high = buf.get_i32_le();
1435 let low = buf.get_u32_le();
1436 let cents = ((high as i64) << 32) | (low as i64);
1437 let value = (cents as f64) / 10000.0;
1438 SqlValue::Double(value)
1439 }
1440 TypeId::Money4 => {
1441 if buf.remaining() < 4 {
1442 return Err(Error::Protocol("unexpected EOF reading SMALLMONEY".into()));
1443 }
1444 let cents = buf.get_i32_le();
1445 let value = (cents as f64) / 10000.0;
1446 SqlValue::Double(value)
1447 }
1448
1449 TypeId::IntN => {
1451 if buf.remaining() < 1 {
1452 return Err(Error::Protocol("unexpected EOF reading IntN length".into()));
1453 }
1454 let len = buf.get_u8();
1455 match len {
1456 0 => SqlValue::Null,
1457 1 => SqlValue::TinyInt(buf.get_u8()),
1458 2 => SqlValue::SmallInt(buf.get_i16_le()),
1459 4 => SqlValue::Int(buf.get_i32_le()),
1460 8 => SqlValue::BigInt(buf.get_i64_le()),
1461 _ => {
1462 return Err(Error::Protocol(format!("invalid IntN length: {len}")));
1463 }
1464 }
1465 }
1466 TypeId::FloatN => {
1467 if buf.remaining() < 1 {
1468 return Err(Error::Protocol(
1469 "unexpected EOF reading FloatN length".into(),
1470 ));
1471 }
1472 let len = buf.get_u8();
1473 match len {
1474 0 => SqlValue::Null,
1475 4 => SqlValue::Float(buf.get_f32_le()),
1476 8 => SqlValue::Double(buf.get_f64_le()),
1477 _ => {
1478 return Err(Error::Protocol(format!("invalid FloatN length: {len}")));
1479 }
1480 }
1481 }
1482 TypeId::BitN => {
1483 if buf.remaining() < 1 {
1484 return Err(Error::Protocol("unexpected EOF reading BitN length".into()));
1485 }
1486 let len = buf.get_u8();
1487 match len {
1488 0 => SqlValue::Null,
1489 1 => SqlValue::Bool(buf.get_u8() != 0),
1490 _ => {
1491 return Err(Error::Protocol(format!("invalid BitN length: {len}")));
1492 }
1493 }
1494 }
1495 TypeId::MoneyN => {
1496 if buf.remaining() < 1 {
1497 return Err(Error::Protocol(
1498 "unexpected EOF reading MoneyN length".into(),
1499 ));
1500 }
1501 let len = buf.get_u8();
1502 match len {
1503 0 => SqlValue::Null,
1504 4 => {
1505 let cents = buf.get_i32_le();
1506 SqlValue::Double((cents as f64) / 10000.0)
1507 }
1508 8 => {
1509 let high = buf.get_i32_le();
1510 let low = buf.get_u32_le();
1511 let cents = ((high as i64) << 32) | (low as i64);
1512 SqlValue::Double((cents as f64) / 10000.0)
1513 }
1514 _ => {
1515 return Err(Error::Protocol(format!("invalid MoneyN length: {len}")));
1516 }
1517 }
1518 }
1519 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1521 if buf.remaining() < 1 {
1522 return Err(Error::Protocol(
1523 "unexpected EOF reading DECIMAL/NUMERIC length".into(),
1524 ));
1525 }
1526 let len = buf.get_u8() as usize;
1527 if len == 0 {
1528 SqlValue::Null
1529 } else {
1530 if buf.remaining() < len {
1531 return Err(Error::Protocol(
1532 "unexpected EOF reading DECIMAL/NUMERIC data".into(),
1533 ));
1534 }
1535
1536 let sign = buf.get_u8();
1538 let mantissa_len = len - 1;
1539
1540 let mut mantissa_bytes = [0u8; 16];
1542 for i in 0..mantissa_len.min(16) {
1543 mantissa_bytes[i] = buf.get_u8();
1544 }
1545 for _ in 16..mantissa_len {
1547 buf.get_u8();
1548 }
1549
1550 let mantissa = u128::from_le_bytes(mantissa_bytes);
1551 let scale = col.type_info.scale.unwrap_or(0) as u32;
1552
1553 #[cfg(feature = "decimal")]
1554 {
1555 use rust_decimal::Decimal;
1556 let mut decimal = Decimal::from_i128_with_scale(mantissa as i128, scale);
1557 if sign == 0 {
1558 decimal.set_sign_negative(true);
1559 }
1560 SqlValue::Decimal(decimal)
1561 }
1562
1563 #[cfg(not(feature = "decimal"))]
1564 {
1565 let divisor = 10f64.powi(scale as i32);
1567 let value = (mantissa as f64) / divisor;
1568 let value = if sign == 0 { -value } else { value };
1569 SqlValue::Double(value)
1570 }
1571 }
1572 }
1573
1574 TypeId::DateTimeN => {
1576 if buf.remaining() < 1 {
1577 return Err(Error::Protocol(
1578 "unexpected EOF reading DateTimeN length".into(),
1579 ));
1580 }
1581 let len = buf.get_u8() as usize;
1582 if len == 0 {
1583 SqlValue::Null
1584 } else if buf.remaining() < len {
1585 return Err(Error::Protocol("unexpected EOF reading DateTimeN".into()));
1586 } else {
1587 match len {
1588 4 => {
1589 let days = buf.get_u16_le() as i64;
1591 let minutes = buf.get_u16_le() as u32;
1592 #[cfg(feature = "chrono")]
1593 {
1594 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1595 let date = base + chrono::Duration::days(days);
1596 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1597 minutes * 60,
1598 0,
1599 )
1600 .unwrap();
1601 SqlValue::DateTime(date.and_time(time))
1602 }
1603 #[cfg(not(feature = "chrono"))]
1604 {
1605 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1606 }
1607 }
1608 8 => {
1609 let days = buf.get_i32_le() as i64;
1611 let time_300ths = buf.get_u32_le() as u64;
1612 #[cfg(feature = "chrono")]
1613 {
1614 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1615 let date = base + chrono::Duration::days(days);
1616 let total_ms = (time_300ths * 1000) / 300;
1618 let secs = (total_ms / 1000) as u32;
1619 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1620 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1621 secs, nanos,
1622 )
1623 .unwrap();
1624 SqlValue::DateTime(date.and_time(time))
1625 }
1626 #[cfg(not(feature = "chrono"))]
1627 {
1628 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1629 }
1630 }
1631 _ => {
1632 return Err(Error::Protocol(format!(
1633 "invalid DateTimeN length: {len}"
1634 )));
1635 }
1636 }
1637 }
1638 }
1639
1640 TypeId::DateTime => {
1642 if buf.remaining() < 8 {
1643 return Err(Error::Protocol("unexpected EOF reading DATETIME".into()));
1644 }
1645 let days = buf.get_i32_le() as i64;
1646 let time_300ths = buf.get_u32_le() as u64;
1647 #[cfg(feature = "chrono")]
1648 {
1649 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1650 let date = base + chrono::Duration::days(days);
1651 let total_ms = (time_300ths * 1000) / 300;
1652 let secs = (total_ms / 1000) as u32;
1653 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1654 let time =
1655 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos).unwrap();
1656 SqlValue::DateTime(date.and_time(time))
1657 }
1658 #[cfg(not(feature = "chrono"))]
1659 {
1660 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1661 }
1662 }
1663
1664 TypeId::DateTime4 => {
1666 if buf.remaining() < 4 {
1667 return Err(Error::Protocol(
1668 "unexpected EOF reading SMALLDATETIME".into(),
1669 ));
1670 }
1671 let days = buf.get_u16_le() as i64;
1672 let minutes = buf.get_u16_le() as u32;
1673 #[cfg(feature = "chrono")]
1674 {
1675 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1676 let date = base + chrono::Duration::days(days);
1677 let time =
1678 chrono::NaiveTime::from_num_seconds_from_midnight_opt(minutes * 60, 0)
1679 .unwrap();
1680 SqlValue::DateTime(date.and_time(time))
1681 }
1682 #[cfg(not(feature = "chrono"))]
1683 {
1684 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1685 }
1686 }
1687
1688 TypeId::Date => {
1690 if buf.remaining() < 1 {
1691 return Err(Error::Protocol("unexpected EOF reading DATE length".into()));
1692 }
1693 let len = buf.get_u8() as usize;
1694 if len == 0 {
1695 SqlValue::Null
1696 } else if len != 3 {
1697 return Err(Error::Protocol(format!("invalid DATE length: {len}")));
1698 } else if buf.remaining() < 3 {
1699 return Err(Error::Protocol("unexpected EOF reading DATE".into()));
1700 } else {
1701 let days = buf.get_u8() as u32
1703 | ((buf.get_u8() as u32) << 8)
1704 | ((buf.get_u8() as u32) << 16);
1705 #[cfg(feature = "chrono")]
1706 {
1707 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1708 let date = base + chrono::Duration::days(days as i64);
1709 SqlValue::Date(date)
1710 }
1711 #[cfg(not(feature = "chrono"))]
1712 {
1713 SqlValue::String(format!("DATE({days})"))
1714 }
1715 }
1716 }
1717
1718 TypeId::Time => {
1720 if buf.remaining() < 1 {
1721 return Err(Error::Protocol("unexpected EOF reading TIME length".into()));
1722 }
1723 let len = buf.get_u8() as usize;
1724 if len == 0 {
1725 SqlValue::Null
1726 } else if buf.remaining() < len {
1727 return Err(Error::Protocol("unexpected EOF reading TIME".into()));
1728 } else {
1729 let scale = col.type_info.scale.unwrap_or(7);
1730 let mut time_bytes = [0u8; 8];
1731 for byte in time_bytes.iter_mut().take(len) {
1732 *byte = buf.get_u8();
1733 }
1734 let intervals = u64::from_le_bytes(time_bytes);
1735 #[cfg(feature = "chrono")]
1736 {
1737 let time = Self::intervals_to_time(intervals, scale);
1738 SqlValue::Time(time)
1739 }
1740 #[cfg(not(feature = "chrono"))]
1741 {
1742 SqlValue::String(format!("TIME({intervals})"))
1743 }
1744 }
1745 }
1746
1747 TypeId::DateTime2 => {
1749 if buf.remaining() < 1 {
1750 return Err(Error::Protocol(
1751 "unexpected EOF reading DATETIME2 length".into(),
1752 ));
1753 }
1754 let len = buf.get_u8() as usize;
1755 if len == 0 {
1756 SqlValue::Null
1757 } else if buf.remaining() < len {
1758 return Err(Error::Protocol("unexpected EOF reading DATETIME2".into()));
1759 } else {
1760 let scale = col.type_info.scale.unwrap_or(7);
1761 let time_len = Self::time_bytes_for_scale(scale);
1762
1763 let mut time_bytes = [0u8; 8];
1765 for byte in time_bytes.iter_mut().take(time_len) {
1766 *byte = buf.get_u8();
1767 }
1768 let intervals = u64::from_le_bytes(time_bytes);
1769
1770 let days = buf.get_u8() as u32
1772 | ((buf.get_u8() as u32) << 8)
1773 | ((buf.get_u8() as u32) << 16);
1774
1775 #[cfg(feature = "chrono")]
1776 {
1777 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1778 let date = base + chrono::Duration::days(days as i64);
1779 let time = Self::intervals_to_time(intervals, scale);
1780 SqlValue::DateTime(date.and_time(time))
1781 }
1782 #[cfg(not(feature = "chrono"))]
1783 {
1784 SqlValue::String(format!("DATETIME2({days},{intervals})"))
1785 }
1786 }
1787 }
1788
1789 TypeId::DateTimeOffset => {
1791 if buf.remaining() < 1 {
1792 return Err(Error::Protocol(
1793 "unexpected EOF reading DATETIMEOFFSET length".into(),
1794 ));
1795 }
1796 let len = buf.get_u8() as usize;
1797 if len == 0 {
1798 SqlValue::Null
1799 } else if buf.remaining() < len {
1800 return Err(Error::Protocol(
1801 "unexpected EOF reading DATETIMEOFFSET".into(),
1802 ));
1803 } else {
1804 let scale = col.type_info.scale.unwrap_or(7);
1805 let time_len = Self::time_bytes_for_scale(scale);
1806
1807 let mut time_bytes = [0u8; 8];
1809 for byte in time_bytes.iter_mut().take(time_len) {
1810 *byte = buf.get_u8();
1811 }
1812 let intervals = u64::from_le_bytes(time_bytes);
1813
1814 let days = buf.get_u8() as u32
1816 | ((buf.get_u8() as u32) << 8)
1817 | ((buf.get_u8() as u32) << 16);
1818
1819 let offset_minutes = buf.get_i16_le();
1821
1822 #[cfg(feature = "chrono")]
1823 {
1824 use chrono::TimeZone;
1825 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1826 let date = base + chrono::Duration::days(days as i64);
1827 let time = Self::intervals_to_time(intervals, scale);
1828 let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
1829 .unwrap_or_else(|| chrono::FixedOffset::east_opt(0).unwrap());
1830 let datetime = offset
1831 .from_local_datetime(&date.and_time(time))
1832 .single()
1833 .unwrap_or_else(|| offset.from_utc_datetime(&date.and_time(time)));
1834 SqlValue::DateTimeOffset(datetime)
1835 }
1836 #[cfg(not(feature = "chrono"))]
1837 {
1838 SqlValue::String(format!(
1839 "DATETIMEOFFSET({days},{intervals},{offset_minutes})"
1840 ))
1841 }
1842 }
1843 }
1844
1845 TypeId::BigVarChar | TypeId::BigChar | TypeId::Text => {
1847 if buf.remaining() < 2 {
1849 return Err(Error::Protocol(
1850 "unexpected EOF reading varchar length".into(),
1851 ));
1852 }
1853 let len = buf.get_u16_le();
1854 if len == 0xFFFF {
1855 SqlValue::Null
1856 } else if buf.remaining() < len as usize {
1857 return Err(Error::Protocol(
1858 "unexpected EOF reading varchar data".into(),
1859 ));
1860 } else {
1861 let data = &buf[..len as usize];
1862 let s = String::from_utf8_lossy(data).into_owned();
1863 buf.advance(len as usize);
1864 SqlValue::String(s)
1865 }
1866 }
1867 TypeId::NVarChar | TypeId::NChar | TypeId::NText => {
1868 if buf.remaining() < 2 {
1870 return Err(Error::Protocol(
1871 "unexpected EOF reading nvarchar length".into(),
1872 ));
1873 }
1874 let len = buf.get_u16_le();
1875 if len == 0xFFFF {
1876 SqlValue::Null
1877 } else if buf.remaining() < len as usize {
1878 return Err(Error::Protocol(
1879 "unexpected EOF reading nvarchar data".into(),
1880 ));
1881 } else {
1882 let data = &buf[..len as usize];
1883 let utf16: Vec<u16> = data
1885 .chunks_exact(2)
1886 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
1887 .collect();
1888 let s = String::from_utf16(&utf16)
1889 .map_err(|_| Error::Protocol("invalid UTF-16 in nvarchar".into()))?;
1890 buf.advance(len as usize);
1891 SqlValue::String(s)
1892 }
1893 }
1894
1895 TypeId::BigVarBinary | TypeId::BigBinary | TypeId::Image => {
1897 if buf.remaining() < 2 {
1898 return Err(Error::Protocol(
1899 "unexpected EOF reading varbinary length".into(),
1900 ));
1901 }
1902 let len = buf.get_u16_le();
1903 if len == 0xFFFF {
1904 SqlValue::Null
1905 } else if buf.remaining() < len as usize {
1906 return Err(Error::Protocol(
1907 "unexpected EOF reading varbinary data".into(),
1908 ));
1909 } else {
1910 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
1911 buf.advance(len as usize);
1912 SqlValue::Binary(data)
1913 }
1914 }
1915
1916 TypeId::Guid => {
1918 if buf.remaining() < 1 {
1919 return Err(Error::Protocol("unexpected EOF reading GUID length".into()));
1920 }
1921 let len = buf.get_u8();
1922 if len == 0 {
1923 SqlValue::Null
1924 } else if len != 16 {
1925 return Err(Error::Protocol(format!("invalid GUID length: {len}")));
1926 } else if buf.remaining() < 16 {
1927 return Err(Error::Protocol("unexpected EOF reading GUID".into()));
1928 } else {
1929 let data = bytes::Bytes::copy_from_slice(&buf[..16]);
1931 buf.advance(16);
1932 SqlValue::Binary(data)
1933 }
1934 }
1935
1936 _ => {
1938 if buf.remaining() < 2 {
1940 return Err(Error::Protocol(format!(
1941 "unexpected EOF reading {:?}",
1942 col.type_id
1943 )));
1944 }
1945 let len = buf.get_u16_le();
1946 if len == 0xFFFF {
1947 SqlValue::Null
1948 } else if buf.remaining() < len as usize {
1949 return Err(Error::Protocol(format!(
1950 "unexpected EOF reading {:?} data",
1951 col.type_id
1952 )));
1953 } else {
1954 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
1955 buf.advance(len as usize);
1956 SqlValue::Binary(data)
1957 }
1958 }
1959 };
1960
1961 Ok(value)
1962 }
1963
1964 fn time_bytes_for_scale(scale: u8) -> usize {
1966 match scale {
1967 0..=2 => 3,
1968 3..=4 => 4,
1969 5..=7 => 5,
1970 _ => 5, }
1972 }
1973
1974 #[cfg(feature = "chrono")]
1976 fn intervals_to_time(intervals: u64, scale: u8) -> chrono::NaiveTime {
1977 let nanos = match scale {
1987 0 => intervals * 1_000_000_000,
1988 1 => intervals * 100_000_000,
1989 2 => intervals * 10_000_000,
1990 3 => intervals * 1_000_000,
1991 4 => intervals * 100_000,
1992 5 => intervals * 10_000,
1993 6 => intervals * 1_000,
1994 7 => intervals * 100,
1995 _ => intervals * 100,
1996 };
1997
1998 let secs = (nanos / 1_000_000_000) as u32;
1999 let nano_part = (nanos % 1_000_000_000) as u32;
2000
2001 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part)
2002 .unwrap_or_else(|| chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap())
2003 }
2004
2005 async fn read_execute_result(&mut self) -> Result<u64> {
2007 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2008
2009 let message = match connection {
2010 ConnectionHandle::Tls(conn) => conn
2011 .read_message()
2012 .await
2013 .map_err(|e| Error::Protocol(e.to_string()))?,
2014 ConnectionHandle::TlsPrelogin(conn) => conn
2015 .read_message()
2016 .await
2017 .map_err(|e| Error::Protocol(e.to_string()))?,
2018 ConnectionHandle::Plain(conn) => conn
2019 .read_message()
2020 .await
2021 .map_err(|e| Error::Protocol(e.to_string()))?,
2022 }
2023 .ok_or(Error::ConnectionClosed)?;
2024
2025 let mut parser = TokenParser::new(message.payload);
2026 let mut rows_affected = 0u64;
2027 let mut current_metadata: Option<ColMetaData> = None;
2028
2029 loop {
2030 let token = parser
2032 .next_token_with_metadata(current_metadata.as_ref())
2033 .map_err(|e| Error::Protocol(e.to_string()))?;
2034
2035 let Some(token) = token else {
2036 break;
2037 };
2038
2039 match token {
2040 Token::ColMetaData(meta) => {
2041 current_metadata = Some(meta);
2043 }
2044 Token::Row(_) | Token::NbcRow(_) => {
2045 }
2048 Token::Done(done) => {
2049 if done.status.error {
2050 return Err(Error::Query("execution failed".to_string()));
2051 }
2052 if done.status.count {
2053 rows_affected = done.row_count;
2054 }
2055 break;
2056 }
2057 Token::DoneProc(done) => {
2058 if done.status.count {
2059 rows_affected = done.row_count;
2060 }
2061 }
2062 Token::DoneInProc(done) => {
2063 if done.status.count {
2064 rows_affected = done.row_count;
2065 }
2066 }
2067 Token::Error(err) => {
2068 return Err(Error::Server {
2069 number: err.number,
2070 state: err.state,
2071 class: err.class,
2072 message: err.message.clone(),
2073 server: if err.server.is_empty() {
2074 None
2075 } else {
2076 Some(err.server.clone())
2077 },
2078 procedure: if err.procedure.is_empty() {
2079 None
2080 } else {
2081 Some(err.procedure.clone())
2082 },
2083 line: err.line as u32,
2084 });
2085 }
2086 Token::Info(info) => {
2087 tracing::info!(
2088 number = info.number,
2089 message = %info.message,
2090 "server info message"
2091 );
2092 }
2093 _ => {}
2094 }
2095 }
2096
2097 Ok(rows_affected)
2098 }
2099
2100 async fn read_transaction_begin_result(&mut self) -> Result<u64> {
2106 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2107
2108 let message = match connection {
2109 ConnectionHandle::Tls(conn) => conn
2110 .read_message()
2111 .await
2112 .map_err(|e| Error::Protocol(e.to_string()))?,
2113 ConnectionHandle::TlsPrelogin(conn) => conn
2114 .read_message()
2115 .await
2116 .map_err(|e| Error::Protocol(e.to_string()))?,
2117 ConnectionHandle::Plain(conn) => conn
2118 .read_message()
2119 .await
2120 .map_err(|e| Error::Protocol(e.to_string()))?,
2121 }
2122 .ok_or(Error::ConnectionClosed)?;
2123
2124 let mut parser = TokenParser::new(message.payload);
2125 let mut transaction_descriptor: u64 = 0;
2126
2127 loop {
2128 let token = parser
2129 .next_token()
2130 .map_err(|e| Error::Protocol(e.to_string()))?;
2131
2132 let Some(token) = token else {
2133 break;
2134 };
2135
2136 match token {
2137 Token::EnvChange(env) => {
2138 if env.env_type == EnvChangeType::BeginTransaction {
2139 if let tds_protocol::token::EnvChangeValue::Binary(ref data) = env.new_value
2142 {
2143 if data.len() >= 8 {
2144 transaction_descriptor = u64::from_le_bytes([
2145 data[0], data[1], data[2], data[3], data[4], data[5], data[6],
2146 data[7],
2147 ]);
2148 tracing::debug!(
2149 transaction_descriptor =
2150 format!("0x{:016X}", transaction_descriptor),
2151 "transaction begun"
2152 );
2153 }
2154 }
2155 }
2156 }
2157 Token::Done(done) => {
2158 if done.status.error {
2159 return Err(Error::Query("BEGIN TRANSACTION failed".to_string()));
2160 }
2161 break;
2162 }
2163 Token::Error(err) => {
2164 return Err(Error::Server {
2165 number: err.number,
2166 state: err.state,
2167 class: err.class,
2168 message: err.message.clone(),
2169 server: if err.server.is_empty() {
2170 None
2171 } else {
2172 Some(err.server.clone())
2173 },
2174 procedure: if err.procedure.is_empty() {
2175 None
2176 } else {
2177 Some(err.procedure.clone())
2178 },
2179 line: err.line as u32,
2180 });
2181 }
2182 Token::Info(info) => {
2183 tracing::info!(
2184 number = info.number,
2185 message = %info.message,
2186 "server info message"
2187 );
2188 }
2189 _ => {}
2190 }
2191 }
2192
2193 Ok(transaction_descriptor)
2194 }
2195}
2196
2197impl Client<Ready> {
2198 pub async fn query<'a>(
2223 &'a mut self,
2224 sql: &str,
2225 params: &[&(dyn crate::ToSql + Sync)],
2226 ) -> Result<QueryStream<'a>> {
2227 tracing::debug!(sql = sql, params_count = params.len(), "executing query");
2228
2229 #[cfg(feature = "otel")]
2230 let instrumentation = self.instrumentation.clone();
2231 #[cfg(feature = "otel")]
2232 let mut span = instrumentation.query_span(sql);
2233
2234 let result = async {
2235 if params.is_empty() {
2236 self.send_sql_batch(sql).await?;
2238 } else {
2239 let rpc_params = Self::convert_params(params)?;
2241 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2242 self.send_rpc(&rpc).await?;
2243 }
2244
2245 self.read_query_response().await
2247 }
2248 .await;
2249
2250 #[cfg(feature = "otel")]
2251 match &result {
2252 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2253 Err(e) => InstrumentationContext::record_error(&mut span, e),
2254 }
2255
2256 #[cfg(feature = "otel")]
2258 drop(span);
2259
2260 let (columns, rows) = result?;
2261 Ok(QueryStream::new(columns, rows))
2262 }
2263
2264 pub async fn query_with_timeout<'a>(
2291 &'a mut self,
2292 sql: &str,
2293 params: &[&(dyn crate::ToSql + Sync)],
2294 timeout_duration: std::time::Duration,
2295 ) -> Result<QueryStream<'a>> {
2296 timeout(timeout_duration, self.query(sql, params))
2297 .await
2298 .map_err(|_| Error::CommandTimeout)?
2299 }
2300
2301 pub async fn query_multiple<'a>(
2328 &'a mut self,
2329 sql: &str,
2330 params: &[&(dyn crate::ToSql + Sync)],
2331 ) -> Result<MultiResultStream<'a>> {
2332 tracing::debug!(
2333 sql = sql,
2334 params_count = params.len(),
2335 "executing multi-result query"
2336 );
2337
2338 if params.is_empty() {
2339 self.send_sql_batch(sql).await?;
2341 } else {
2342 let rpc_params = Self::convert_params(params)?;
2344 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2345 self.send_rpc(&rpc).await?;
2346 }
2347
2348 let result_sets = self.read_multi_result_response().await?;
2350 Ok(MultiResultStream::new(result_sets))
2351 }
2352
2353 async fn read_multi_result_response(&mut self) -> Result<Vec<crate::stream::ResultSet>> {
2355 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2356
2357 let message = match connection {
2358 ConnectionHandle::Tls(conn) => conn
2359 .read_message()
2360 .await
2361 .map_err(|e| Error::Protocol(e.to_string()))?,
2362 ConnectionHandle::TlsPrelogin(conn) => conn
2363 .read_message()
2364 .await
2365 .map_err(|e| Error::Protocol(e.to_string()))?,
2366 ConnectionHandle::Plain(conn) => conn
2367 .read_message()
2368 .await
2369 .map_err(|e| Error::Protocol(e.to_string()))?,
2370 }
2371 .ok_or(Error::ConnectionClosed)?;
2372
2373 let mut parser = TokenParser::new(message.payload);
2374 let mut result_sets: Vec<crate::stream::ResultSet> = Vec::new();
2375 let mut current_columns: Vec<crate::row::Column> = Vec::new();
2376 let mut current_rows: Vec<crate::row::Row> = Vec::new();
2377 let mut protocol_metadata: Option<ColMetaData> = None;
2378
2379 loop {
2380 let token = parser
2381 .next_token_with_metadata(protocol_metadata.as_ref())
2382 .map_err(|e| Error::Protocol(e.to_string()))?;
2383
2384 let Some(token) = token else {
2385 break;
2386 };
2387
2388 match token {
2389 Token::ColMetaData(meta) => {
2390 if !current_columns.is_empty() {
2392 result_sets.push(crate::stream::ResultSet::new(
2393 std::mem::take(&mut current_columns),
2394 std::mem::take(&mut current_rows),
2395 ));
2396 }
2397
2398 current_columns = meta
2400 .columns
2401 .iter()
2402 .enumerate()
2403 .map(|(i, col)| {
2404 let type_name = format!("{:?}", col.type_id);
2405 let mut column = crate::row::Column::new(&col.name, i, type_name)
2406 .with_nullable(col.flags & 0x01 != 0);
2407
2408 if let Some(max_len) = col.type_info.max_length {
2409 column = column.with_max_length(max_len);
2410 }
2411 if let (Some(prec), Some(scale)) =
2412 (col.type_info.precision, col.type_info.scale)
2413 {
2414 column = column.with_precision_scale(prec, scale);
2415 }
2416 column
2417 })
2418 .collect();
2419
2420 tracing::debug!(
2421 columns = current_columns.len(),
2422 result_set = result_sets.len(),
2423 "received column metadata for result set"
2424 );
2425 protocol_metadata = Some(meta);
2426 }
2427 Token::Row(raw_row) => {
2428 if let Some(ref meta) = protocol_metadata {
2429 let row = Self::convert_raw_row(&raw_row, meta, ¤t_columns)?;
2430 current_rows.push(row);
2431 }
2432 }
2433 Token::NbcRow(nbc_row) => {
2434 if let Some(ref meta) = protocol_metadata {
2435 let row = Self::convert_nbc_row(&nbc_row, meta, ¤t_columns)?;
2436 current_rows.push(row);
2437 }
2438 }
2439 Token::Error(err) => {
2440 return Err(Error::Server {
2441 number: err.number,
2442 state: err.state,
2443 class: err.class,
2444 message: err.message.clone(),
2445 server: if err.server.is_empty() {
2446 None
2447 } else {
2448 Some(err.server.clone())
2449 },
2450 procedure: if err.procedure.is_empty() {
2451 None
2452 } else {
2453 Some(err.procedure.clone())
2454 },
2455 line: err.line as u32,
2456 });
2457 }
2458 Token::Done(done) => {
2459 if done.status.error {
2460 return Err(Error::Query("query failed".to_string()));
2461 }
2462
2463 if !current_columns.is_empty() {
2465 result_sets.push(crate::stream::ResultSet::new(
2466 std::mem::take(&mut current_columns),
2467 std::mem::take(&mut current_rows),
2468 ));
2469 protocol_metadata = None;
2470 }
2471
2472 if !done.status.more {
2474 tracing::debug!(result_sets = result_sets.len(), "all result sets parsed");
2475 break;
2476 }
2477 }
2478 Token::DoneInProc(done) => {
2479 if done.status.error {
2480 return Err(Error::Query("query failed".to_string()));
2481 }
2482
2483 if !current_columns.is_empty() {
2485 result_sets.push(crate::stream::ResultSet::new(
2486 std::mem::take(&mut current_columns),
2487 std::mem::take(&mut current_rows),
2488 ));
2489 protocol_metadata = None;
2490 }
2491
2492 if !done.status.more {
2494 }
2496 }
2497 Token::DoneProc(done) => {
2498 if done.status.error {
2499 return Err(Error::Query("query failed".to_string()));
2500 }
2501 }
2503 Token::Info(info) => {
2504 tracing::debug!(
2505 number = info.number,
2506 message = %info.message,
2507 "server info message"
2508 );
2509 }
2510 _ => {}
2511 }
2512 }
2513
2514 if !current_columns.is_empty() {
2516 result_sets.push(crate::stream::ResultSet::new(current_columns, current_rows));
2517 }
2518
2519 Ok(result_sets)
2520 }
2521
2522 pub async fn execute(
2526 &mut self,
2527 sql: &str,
2528 params: &[&(dyn crate::ToSql + Sync)],
2529 ) -> Result<u64> {
2530 tracing::debug!(
2531 sql = sql,
2532 params_count = params.len(),
2533 "executing statement"
2534 );
2535
2536 #[cfg(feature = "otel")]
2537 let instrumentation = self.instrumentation.clone();
2538 #[cfg(feature = "otel")]
2539 let mut span = instrumentation.query_span(sql);
2540
2541 let result = async {
2542 if params.is_empty() {
2543 self.send_sql_batch(sql).await?;
2545 } else {
2546 let rpc_params = Self::convert_params(params)?;
2548 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2549 self.send_rpc(&rpc).await?;
2550 }
2551
2552 self.read_execute_result().await
2554 }
2555 .await;
2556
2557 #[cfg(feature = "otel")]
2558 match &result {
2559 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
2560 Err(e) => InstrumentationContext::record_error(&mut span, e),
2561 }
2562
2563 #[cfg(feature = "otel")]
2565 drop(span);
2566
2567 result
2568 }
2569
2570 pub async fn execute_with_timeout(
2597 &mut self,
2598 sql: &str,
2599 params: &[&(dyn crate::ToSql + Sync)],
2600 timeout_duration: std::time::Duration,
2601 ) -> Result<u64> {
2602 timeout(timeout_duration, self.execute(sql, params))
2603 .await
2604 .map_err(|_| Error::CommandTimeout)?
2605 }
2606
2607 pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
2614 tracing::debug!("beginning transaction");
2615
2616 #[cfg(feature = "otel")]
2617 let instrumentation = self.instrumentation.clone();
2618 #[cfg(feature = "otel")]
2619 let mut span = instrumentation.transaction_span("BEGIN");
2620
2621 let result = async {
2623 self.send_sql_batch("BEGIN TRANSACTION").await?;
2624 self.read_transaction_begin_result().await
2625 }
2626 .await;
2627
2628 #[cfg(feature = "otel")]
2629 match &result {
2630 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2631 Err(e) => InstrumentationContext::record_error(&mut span, e),
2632 }
2633
2634 #[cfg(feature = "otel")]
2636 drop(span);
2637
2638 let transaction_descriptor = result?;
2639
2640 Ok(Client {
2641 config: self.config,
2642 _state: PhantomData,
2643 connection: self.connection,
2644 server_version: self.server_version,
2645 current_database: self.current_database,
2646 statement_cache: self.statement_cache,
2647 transaction_descriptor, #[cfg(feature = "otel")]
2649 instrumentation: self.instrumentation,
2650 })
2651 }
2652
2653 pub async fn begin_transaction_with_isolation(
2668 mut self,
2669 isolation_level: crate::transaction::IsolationLevel,
2670 ) -> Result<Client<InTransaction>> {
2671 tracing::debug!(
2672 isolation_level = %isolation_level.name(),
2673 "beginning transaction with isolation level"
2674 );
2675
2676 #[cfg(feature = "otel")]
2677 let instrumentation = self.instrumentation.clone();
2678 #[cfg(feature = "otel")]
2679 let mut span = instrumentation.transaction_span("BEGIN");
2680
2681 let result = async {
2683 self.send_sql_batch(isolation_level.as_sql()).await?;
2684 self.read_execute_result().await?;
2685
2686 self.send_sql_batch("BEGIN TRANSACTION").await?;
2688 self.read_transaction_begin_result().await
2689 }
2690 .await;
2691
2692 #[cfg(feature = "otel")]
2693 match &result {
2694 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2695 Err(e) => InstrumentationContext::record_error(&mut span, e),
2696 }
2697
2698 #[cfg(feature = "otel")]
2699 drop(span);
2700
2701 let transaction_descriptor = result?;
2702
2703 Ok(Client {
2704 config: self.config,
2705 _state: PhantomData,
2706 connection: self.connection,
2707 server_version: self.server_version,
2708 current_database: self.current_database,
2709 statement_cache: self.statement_cache,
2710 transaction_descriptor,
2711 #[cfg(feature = "otel")]
2712 instrumentation: self.instrumentation,
2713 })
2714 }
2715
2716 pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
2721 tracing::debug!(sql = sql, "executing simple query");
2722
2723 self.send_sql_batch(sql).await?;
2725
2726 let _ = self.read_execute_result().await?;
2728
2729 Ok(())
2730 }
2731
2732 pub async fn close(self) -> Result<()> {
2734 tracing::debug!("closing connection");
2735 Ok(())
2736 }
2737
2738 #[must_use]
2740 pub fn database(&self) -> Option<&str> {
2741 self.config.database.as_deref()
2742 }
2743
/// Returns the server host this client was configured with (`config.host`).
#[must_use]
pub fn host(&self) -> &str {
    &self.config.host
}
2749
/// Returns the server port this client was configured with (`config.port`).
#[must_use]
pub fn port(&self) -> u16 {
    self.config.port
}
2755
2756 #[must_use]
2778 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
2779 let connection = self
2780 .connection
2781 .as_ref()
2782 .expect("connection should be present");
2783 match connection {
2784 ConnectionHandle::Tls(conn) => {
2785 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
2786 }
2787 ConnectionHandle::TlsPrelogin(conn) => {
2788 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
2789 }
2790 ConnectionHandle::Plain(conn) => {
2791 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
2792 }
2793 }
2794 }
2795}
2796
2797impl Client<InTransaction> {
2798 pub async fn query<'a>(
2802 &'a mut self,
2803 sql: &str,
2804 params: &[&(dyn crate::ToSql + Sync)],
2805 ) -> Result<QueryStream<'a>> {
2806 tracing::debug!(
2807 sql = sql,
2808 params_count = params.len(),
2809 "executing query in transaction"
2810 );
2811
2812 #[cfg(feature = "otel")]
2813 let instrumentation = self.instrumentation.clone();
2814 #[cfg(feature = "otel")]
2815 let mut span = instrumentation.query_span(sql);
2816
2817 let result = async {
2818 if params.is_empty() {
2819 self.send_sql_batch(sql).await?;
2821 } else {
2822 let rpc_params = Self::convert_params(params)?;
2824 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2825 self.send_rpc(&rpc).await?;
2826 }
2827
2828 self.read_query_response().await
2830 }
2831 .await;
2832
2833 #[cfg(feature = "otel")]
2834 match &result {
2835 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2836 Err(e) => InstrumentationContext::record_error(&mut span, e),
2837 }
2838
2839 #[cfg(feature = "otel")]
2841 drop(span);
2842
2843 let (columns, rows) = result?;
2844 Ok(QueryStream::new(columns, rows))
2845 }
2846
2847 pub async fn execute(
2851 &mut self,
2852 sql: &str,
2853 params: &[&(dyn crate::ToSql + Sync)],
2854 ) -> Result<u64> {
2855 tracing::debug!(
2856 sql = sql,
2857 params_count = params.len(),
2858 "executing statement in transaction"
2859 );
2860
2861 #[cfg(feature = "otel")]
2862 let instrumentation = self.instrumentation.clone();
2863 #[cfg(feature = "otel")]
2864 let mut span = instrumentation.query_span(sql);
2865
2866 let result = async {
2867 if params.is_empty() {
2868 self.send_sql_batch(sql).await?;
2870 } else {
2871 let rpc_params = Self::convert_params(params)?;
2873 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2874 self.send_rpc(&rpc).await?;
2875 }
2876
2877 self.read_execute_result().await
2879 }
2880 .await;
2881
2882 #[cfg(feature = "otel")]
2883 match &result {
2884 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
2885 Err(e) => InstrumentationContext::record_error(&mut span, e),
2886 }
2887
2888 #[cfg(feature = "otel")]
2890 drop(span);
2891
2892 result
2893 }
2894
2895 pub async fn query_with_timeout<'a>(
2899 &'a mut self,
2900 sql: &str,
2901 params: &[&(dyn crate::ToSql + Sync)],
2902 timeout_duration: std::time::Duration,
2903 ) -> Result<QueryStream<'a>> {
2904 timeout(timeout_duration, self.query(sql, params))
2905 .await
2906 .map_err(|_| Error::CommandTimeout)?
2907 }
2908
2909 pub async fn execute_with_timeout(
2913 &mut self,
2914 sql: &str,
2915 params: &[&(dyn crate::ToSql + Sync)],
2916 timeout_duration: std::time::Duration,
2917 ) -> Result<u64> {
2918 timeout(timeout_duration, self.execute(sql, params))
2919 .await
2920 .map_err(|_| Error::CommandTimeout)?
2921 }
2922
2923 pub async fn commit(mut self) -> Result<Client<Ready>> {
2927 tracing::debug!("committing transaction");
2928
2929 #[cfg(feature = "otel")]
2930 let instrumentation = self.instrumentation.clone();
2931 #[cfg(feature = "otel")]
2932 let mut span = instrumentation.transaction_span("COMMIT");
2933
2934 let result = async {
2936 self.send_sql_batch("COMMIT TRANSACTION").await?;
2937 self.read_execute_result().await
2938 }
2939 .await;
2940
2941 #[cfg(feature = "otel")]
2942 match &result {
2943 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2944 Err(e) => InstrumentationContext::record_error(&mut span, e),
2945 }
2946
2947 #[cfg(feature = "otel")]
2949 drop(span);
2950
2951 result?;
2952
2953 Ok(Client {
2954 config: self.config,
2955 _state: PhantomData,
2956 connection: self.connection,
2957 server_version: self.server_version,
2958 current_database: self.current_database,
2959 statement_cache: self.statement_cache,
2960 transaction_descriptor: 0, #[cfg(feature = "otel")]
2962 instrumentation: self.instrumentation,
2963 })
2964 }
2965
2966 pub async fn rollback(mut self) -> Result<Client<Ready>> {
2970 tracing::debug!("rolling back transaction");
2971
2972 #[cfg(feature = "otel")]
2973 let instrumentation = self.instrumentation.clone();
2974 #[cfg(feature = "otel")]
2975 let mut span = instrumentation.transaction_span("ROLLBACK");
2976
2977 let result = async {
2979 self.send_sql_batch("ROLLBACK TRANSACTION").await?;
2980 self.read_execute_result().await
2981 }
2982 .await;
2983
2984 #[cfg(feature = "otel")]
2985 match &result {
2986 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2987 Err(e) => InstrumentationContext::record_error(&mut span, e),
2988 }
2989
2990 #[cfg(feature = "otel")]
2992 drop(span);
2993
2994 result?;
2995
2996 Ok(Client {
2997 config: self.config,
2998 _state: PhantomData,
2999 connection: self.connection,
3000 server_version: self.server_version,
3001 current_database: self.current_database,
3002 statement_cache: self.statement_cache,
3003 transaction_descriptor: 0, #[cfg(feature = "otel")]
3005 instrumentation: self.instrumentation,
3006 })
3007 }
3008
3009 pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
3026 validate_identifier(name)?;
3027 tracing::debug!(name = name, "creating savepoint");
3028
3029 let sql = format!("SAVE TRANSACTION {}", name);
3032 self.send_sql_batch(&sql).await?;
3033 self.read_execute_result().await?;
3034
3035 Ok(SavePoint::new(name.to_string()))
3036 }
3037
3038 pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
3053 tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
3054
3055 let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
3058 self.send_sql_batch(&sql).await?;
3059 self.read_execute_result().await?;
3060
3061 Ok(())
3062 }
3063
3064 pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
3070 tracing::debug!(name = savepoint.name(), "releasing savepoint");
3071
3072 drop(savepoint);
3076 Ok(())
3077 }
3078
3079 #[must_use]
3083 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
3084 let connection = self
3085 .connection
3086 .as_ref()
3087 .expect("connection should be present");
3088 match connection {
3089 ConnectionHandle::Tls(conn) => {
3090 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
3091 }
3092 ConnectionHandle::TlsPrelogin(conn) => {
3093 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
3094 }
3095 ConnectionHandle::Plain(conn) => {
3096 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
3097 }
3098 }
3099 }
3100}
3101
3102fn validate_identifier(name: &str) -> Result<()> {
3104 use once_cell::sync::Lazy;
3105 use regex::Regex;
3106
3107 static IDENTIFIER_RE: Lazy<Regex> =
3108 Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").unwrap());
3109
3110 if name.is_empty() {
3111 return Err(Error::InvalidIdentifier(
3112 "identifier cannot be empty".into(),
3113 ));
3114 }
3115
3116 if !IDENTIFIER_RE.is_match(name) {
3117 return Err(Error::InvalidIdentifier(format!(
3118 "invalid identifier '{}': must start with letter/underscore, \
3119 contain only alphanumerics/_/@/#/$, and be 1-128 characters",
3120 name
3121 )));
3122 }
3123
3124 Ok(())
3125}
3126
3127impl<S: ConnectionState> std::fmt::Debug for Client<S> {
3128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3129 f.debug_struct("Client")
3130 .field("host", &self.config.host)
3131 .field("port", &self.config.port)
3132 .field("database", &self.config.database)
3133 .finish()
3134 }
3135}
3136
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_identifier_valid() {
        assert!(validate_identifier("my_table").is_ok());
        assert!(validate_identifier("Table123").is_ok());
        assert!(validate_identifier("_private").is_ok());
        assert!(validate_identifier("sp_test").is_ok());
    }

    #[test]
    fn test_validate_identifier_invalid() {
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("123abc").is_err());
        assert!(validate_identifier("table-name").is_err());
        assert!(validate_identifier("table name").is_err());
        assert!(validate_identifier("table;DROP TABLE users").is_err());
        // A trailing newline must not slip past the end-of-input check.
        assert!(validate_identifier("abc\n").is_err());
        // Only ASCII letters are accepted.
        assert!(validate_identifier("café").is_err());
    }

    #[test]
    fn test_validate_identifier_special_characters() {
        // `@`, `#` and `$` are allowed after the first character...
        assert!(validate_identifier("a@b#c$d").is_ok());
        // ...but never as the first character.
        assert!(validate_identifier("@var").is_err());
        assert!(validate_identifier("#temp").is_err());
        assert!(validate_identifier("$x").is_err());
    }

    #[test]
    fn test_validate_identifier_length_limit() {
        // 128 characters is the longest accepted identifier.
        assert!(validate_identifier(&"a".repeat(128)).is_ok());
        assert!(validate_identifier(&"a".repeat(129)).is_err());
        // A single character is the shortest.
        assert!(validate_identifier("x").is_ok());
    }
}