1#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10use bytes::BytesMut;
11use mssql_codec::connection::Connection;
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode, TlsStream};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::{MAX_PACKET_SIZE, PacketType};
15use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
16use tds_protocol::rpc::{RpcParam, RpcRequest, TypeInfo as RpcTypeInfo};
17use tds_protocol::token::{
18 ColMetaData, ColumnData, EnvChange, EnvChangeType, NbcRow, RawRow, Token, TokenParser,
19};
20use tokio::net::TcpStream;
21use tokio::time::timeout;
22
23use crate::config::Config;
24use crate::error::{Error, Result};
25#[cfg(feature = "otel")]
26use crate::instrumentation::InstrumentationContext;
27use crate::state::{ConnectionState, Disconnected, InTransaction, Ready};
28use crate::statement_cache::StatementCache;
29use crate::stream::{MultiResultStream, QueryStream};
30use crate::transaction::SavePoint;
31
/// A SQL Server client, parameterized by its connection state `S`
/// (for example [`Disconnected`], [`Ready`] or [`InTransaction`]).
///
/// The state exists only at the type level (`PhantomData`), so transitions
/// between states carry no runtime cost.
pub struct Client<S: ConnectionState> {
    /// Configuration this client was connected with.
    config: Config,
    /// Zero-sized marker carrying the connection state.
    _state: PhantomData<S>,
    /// Underlying transport; `None` means there is no usable connection
    /// (request helpers turn this into `Error::ConnectionClosed`).
    connection: Option<ConnectionHandle>,
    /// TDS version reported by the server's LOGINACK token.
    server_version: Option<u32>,
    /// Current database, as last reported by a database ENVCHANGE token.
    current_database: Option<String>,
    /// Per-client statement cache (see `StatementCache`).
    statement_cache: StatementCache,
    /// Transaction descriptor attached to every SQL batch and RPC request;
    /// initialised to 0 when the connection is established.
    transaction_descriptor: u64,
    /// OpenTelemetry instrumentation context (only with the `otel` feature).
    #[cfg(feature = "otel")]
    instrumentation: InstrumentationContext,
}
56
/// The concrete transport a [`Client`] talks over.
// NOTE(review): `dead_code` is allowed here presumably because not every
// variant/field is read in every build configuration — confirm and narrow.
#[allow(dead_code)]
enum ConnectionHandle {
    /// TDS 8.0 strict mode: TLS negotiated directly on the TCP socket.
    Tls(Connection<TlsStream<TcpStream>>),
    /// TDS 7.x full encryption: TLS handshake tunnelled inside PreLogin packets.
    TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
    /// Plain TCP: either no TLS at all, or login-only encryption after LOGIN7.
    Plain(Connection<TcpStream>),
}
72
73impl Client<Disconnected> {
74 pub async fn connect(config: Config) -> Result<Client<Ready>> {
85 let max_redirects = config.redirect.max_redirects;
86 let follow_redirects = config.redirect.follow_redirects;
87 let mut attempts = 0;
88 let mut current_config = config;
89
90 loop {
91 attempts += 1;
92 if attempts > max_redirects + 1 {
93 return Err(Error::TooManyRedirects { max: max_redirects });
94 }
95
96 match Self::try_connect(¤t_config).await {
97 Ok(client) => return Ok(client),
98 Err(Error::Routing { host, port }) => {
99 if !follow_redirects {
100 return Err(Error::Routing { host, port });
101 }
102 tracing::info!(
103 host = %host,
104 port = port,
105 attempt = attempts,
106 max_redirects = max_redirects,
107 "following Azure SQL routing redirect"
108 );
109 current_config = current_config.with_host(&host).with_port(port);
110 continue;
111 }
112 Err(e) => return Err(e),
113 }
114 }
115 }
116
117 async fn try_connect(config: &Config) -> Result<Client<Ready>> {
118 tracing::info!(
119 host = %config.host,
120 port = config.port,
121 database = ?config.database,
122 "connecting to SQL Server"
123 );
124
125 let addr = format!("{}:{}", config.host, config.port);
126
127 tracing::debug!("establishing TCP connection to {}", addr);
129 let tcp_stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
130 .await
131 .map_err(|_| Error::ConnectTimeout)?
132 .map_err(|e| Error::Io(Arc::new(e)))?;
133
134 tcp_stream
136 .set_nodelay(true)
137 .map_err(|e| Error::Io(Arc::new(e)))?;
138
139 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
141
142 if tls_mode.is_tls_first() {
144 return Self::connect_tds_8(config, tcp_stream).await;
145 }
146
147 Self::connect_tds_7x(config, tcp_stream).await
149 }
150
151 async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
155 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
156
157 let tls_config = TlsConfig::new()
159 .strict_mode(true)
160 .trust_server_certificate(config.trust_server_certificate);
161
162 let tls_connector = TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
163
164 let tls_stream = timeout(
166 config.timeouts.tls_timeout,
167 tls_connector.connect(tcp_stream, &config.host),
168 )
169 .await
170 .map_err(|_| Error::TlsTimeout)?
171 .map_err(|e| Error::Tls(e.to_string()))?;
172
173 tracing::debug!("TLS handshake completed (strict mode)");
174
175 let mut connection = Connection::new(tls_stream);
177
178 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
180 Self::send_prelogin(&mut connection, &prelogin).await?;
181 let _prelogin_response = Self::receive_prelogin(&mut connection).await?;
182
183 let login = Self::build_login7(config);
185 Self::send_login7(&mut connection, &login).await?;
186
187 let (server_version, current_database, routing) =
189 Self::process_login_response(&mut connection).await?;
190
191 if let Some((host, port)) = routing {
193 return Err(Error::Routing { host, port });
194 }
195
196 Ok(Client {
197 config: config.clone(),
198 _state: PhantomData,
199 connection: Some(ConnectionHandle::Tls(connection)),
200 server_version,
201 current_database: current_database.clone(),
202 statement_cache: StatementCache::with_default_size(),
203 transaction_descriptor: 0, #[cfg(feature = "otel")]
205 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
206 .with_database(current_database.unwrap_or_default()),
207 })
208 }
209
210 async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
218 use bytes::BufMut;
219 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
220 use tokio::io::{AsyncReadExt, AsyncWriteExt};
221
222 tracing::debug!("using TDS 7.x flow (PreLogin first)");
223
224 let client_encryption = if config.encrypt {
227 EncryptionLevel::On
228 } else {
229 EncryptionLevel::Off
230 };
231 let prelogin = Self::build_prelogin(config, client_encryption);
232 tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
233 let prelogin_bytes = prelogin.encode();
234
235 let header = PacketHeader::new(
237 PacketType::PreLogin,
238 PacketStatus::END_OF_MESSAGE,
239 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
240 );
241
242 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
243 header.encode(&mut packet_buf);
244 packet_buf.put_slice(&prelogin_bytes);
245
246 tcp_stream
247 .write_all(&packet_buf)
248 .await
249 .map_err(|e| Error::Io(Arc::new(e)))?;
250
251 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
253 tcp_stream
254 .read_exact(&mut header_buf)
255 .await
256 .map_err(|e| Error::Io(Arc::new(e)))?;
257
258 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
259 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
260
261 let mut response_buf = vec![0u8; payload_length];
262 tcp_stream
263 .read_exact(&mut response_buf)
264 .await
265 .map_err(|e| Error::Io(Arc::new(e)))?;
266
267 let prelogin_response =
268 PreLogin::decode(&response_buf[..]).map_err(|e| Error::Protocol(e.to_string()))?;
269
270 let server_encryption = prelogin_response.encryption;
272 tracing::debug!(encryption = ?server_encryption, "server encryption level");
273
274 let negotiated_encryption = match (client_encryption, server_encryption) {
280 (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
281 EncryptionLevel::NotSupported
282 }
283 (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
284 (EncryptionLevel::On, EncryptionLevel::Off)
285 | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
286 return Err(Error::Protocol(
287 "Server does not support requested encryption level".to_string(),
288 ));
289 }
290 _ => EncryptionLevel::On,
291 };
292
293 let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
296
297 if use_tls {
298 let tls_config =
301 TlsConfig::new().trust_server_certificate(config.trust_server_certificate);
302
303 let tls_connector =
304 TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
305
306 let mut tls_stream = timeout(
308 config.timeouts.tls_timeout,
309 tls_connector.connect_with_prelogin(tcp_stream, &config.host),
310 )
311 .await
312 .map_err(|_| Error::TlsTimeout)?
313 .map_err(|e| Error::Tls(e.to_string()))?;
314
315 tracing::debug!("TLS handshake completed (PreLogin wrapped)");
316
317 let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
319
320 if login_only_encryption {
321 use tokio::io::AsyncWriteExt;
329
330 let login = Self::build_login7(config);
332 let login_payload = login.encode();
333
334 let max_packet = MAX_PACKET_SIZE;
336 let max_payload = max_packet - PACKET_HEADER_SIZE;
337 let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
338 let total_chunks = chunks.len();
339
340 for (i, chunk) in chunks.into_iter().enumerate() {
341 let is_last = i == total_chunks - 1;
342 let status = if is_last {
343 PacketStatus::END_OF_MESSAGE
344 } else {
345 PacketStatus::NORMAL
346 };
347
348 let header = PacketHeader::new(
349 PacketType::Tds7Login,
350 status,
351 (PACKET_HEADER_SIZE + chunk.len()) as u16,
352 );
353
354 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
355 header.encode(&mut packet_buf);
356 packet_buf.put_slice(chunk);
357
358 tls_stream
359 .write_all(&packet_buf)
360 .await
361 .map_err(|e| Error::Io(Arc::new(e)))?;
362 }
363
364 tls_stream
366 .flush()
367 .await
368 .map_err(|e| Error::Io(Arc::new(e)))?;
369
370 tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
371
372 let (wrapper, _client_conn) = tls_stream.into_inner();
376 let tcp_stream = wrapper.into_inner();
377
378 let mut connection = Connection::new(tcp_stream);
380
381 let (server_version, current_database, routing) =
383 Self::process_login_response(&mut connection).await?;
384
385 if let Some((host, port)) = routing {
387 return Err(Error::Routing { host, port });
388 }
389
390 Ok(Client {
392 config: config.clone(),
393 _state: PhantomData,
394 connection: Some(ConnectionHandle::Plain(connection)),
395 server_version,
396 current_database: current_database.clone(),
397 statement_cache: StatementCache::with_default_size(),
398 transaction_descriptor: 0, #[cfg(feature = "otel")]
400 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
401 .with_database(current_database.unwrap_or_default()),
402 })
403 } else {
404 let mut connection = Connection::new(tls_stream);
407
408 let login = Self::build_login7(config);
410 Self::send_login7(&mut connection, &login).await?;
411
412 let (server_version, current_database, routing) =
414 Self::process_login_response(&mut connection).await?;
415
416 if let Some((host, port)) = routing {
418 return Err(Error::Routing { host, port });
419 }
420
421 Ok(Client {
422 config: config.clone(),
423 _state: PhantomData,
424 connection: Some(ConnectionHandle::TlsPrelogin(connection)),
425 server_version,
426 current_database: current_database.clone(),
427 statement_cache: StatementCache::with_default_size(),
428 transaction_descriptor: 0, #[cfg(feature = "otel")]
430 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
431 .with_database(current_database.unwrap_or_default()),
432 })
433 }
434 } else {
435 tracing::warn!(
437 "Connecting without TLS encryption. This is insecure and should only be \
438 used for development/testing on trusted networks."
439 );
440
441 let login = Self::build_login7(config);
443 let login_bytes = login.encode();
444 tracing::debug!("Login7 packet built: {} bytes", login_bytes.len(),);
445 tracing::debug!(
447 "Login7 fixed header (94 bytes): {:02X?}",
448 &login_bytes[..login_bytes.len().min(94)]
449 );
450 if login_bytes.len() > 94 {
452 tracing::debug!(
453 "Login7 variable data ({} bytes): {:02X?}",
454 login_bytes.len() - 94,
455 &login_bytes[94..]
456 );
457 }
458
459 let login_header = PacketHeader::new(
461 PacketType::Tds7Login,
462 PacketStatus::END_OF_MESSAGE,
463 (PACKET_HEADER_SIZE + login_bytes.len()) as u16,
464 )
465 .with_packet_id(1);
466 let mut login_packet_buf =
467 BytesMut::with_capacity(PACKET_HEADER_SIZE + login_bytes.len());
468 login_header.encode(&mut login_packet_buf);
469 login_packet_buf.put_slice(&login_bytes);
470
471 tracing::debug!(
472 "Sending Login7 packet: {} bytes total, header: {:02X?}",
473 login_packet_buf.len(),
474 &login_packet_buf[..PACKET_HEADER_SIZE]
475 );
476 tcp_stream
477 .write_all(&login_packet_buf)
478 .await
479 .map_err(|e| Error::Io(Arc::new(e)))?;
480 tcp_stream
481 .flush()
482 .await
483 .map_err(|e| Error::Io(Arc::new(e)))?;
484 tracing::debug!("Login7 sent and flushed over raw TCP");
485
486 let mut response_header_buf = [0u8; PACKET_HEADER_SIZE];
488 tcp_stream
489 .read_exact(&mut response_header_buf)
490 .await
491 .map_err(|e| Error::Io(Arc::new(e)))?;
492
493 let response_type = response_header_buf[0];
494 let response_length =
495 u16::from_be_bytes([response_header_buf[2], response_header_buf[3]]) as usize;
496 tracing::debug!(
497 "Response header: type={:#04X}, length={}",
498 response_type,
499 response_length
500 );
501
502 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
504 let mut response_payload = vec![0u8; payload_length];
505 tcp_stream
506 .read_exact(&mut response_payload)
507 .await
508 .map_err(|e| Error::Io(Arc::new(e)))?;
509 tracing::debug!(
510 "Response payload: {} bytes, first 32: {:02X?}",
511 response_payload.len(),
512 &response_payload[..response_payload.len().min(32)]
513 );
514
515 let connection = Connection::new(tcp_stream);
517
518 let response_bytes = bytes::Bytes::from(response_payload);
520 let mut parser = TokenParser::new(response_bytes);
521 let mut server_version = None;
522 let mut current_database = None;
523 let routing = None;
524
525 while let Some(token) = parser
526 .next_token()
527 .map_err(|e| Error::Protocol(e.to_string()))?
528 {
529 match token {
530 Token::LoginAck(ack) => {
531 tracing::info!(
532 version = ack.tds_version,
533 interface = ack.interface,
534 prog_name = %ack.prog_name,
535 "login acknowledged"
536 );
537 server_version = Some(ack.tds_version);
538 }
539 Token::EnvChange(env) => {
540 Self::process_env_change(&env, &mut current_database, &mut None);
541 }
542 Token::Error(err) => {
543 return Err(Error::Server {
544 number: err.number,
545 state: err.state,
546 class: err.class,
547 message: err.message.clone(),
548 server: if err.server.is_empty() {
549 None
550 } else {
551 Some(err.server.clone())
552 },
553 procedure: if err.procedure.is_empty() {
554 None
555 } else {
556 Some(err.procedure.clone())
557 },
558 line: err.line as u32,
559 });
560 }
561 Token::Info(info) => {
562 tracing::info!(
563 number = info.number,
564 message = %info.message,
565 "server info message"
566 );
567 }
568 Token::Done(done) => {
569 if done.status.error {
570 return Err(Error::Protocol("login failed".to_string()));
571 }
572 break;
573 }
574 _ => {}
575 }
576 }
577
578 if let Some((host, port)) = routing {
580 return Err(Error::Routing { host, port });
581 }
582
583 Ok(Client {
584 config: config.clone(),
585 _state: PhantomData,
586 connection: Some(ConnectionHandle::Plain(connection)),
587 server_version,
588 current_database: current_database.clone(),
589 statement_cache: StatementCache::with_default_size(),
590 transaction_descriptor: 0, #[cfg(feature = "otel")]
592 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
593 .with_database(current_database.unwrap_or_default()),
594 })
595 }
596 }
597
598 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
600 let mut prelogin = PreLogin::new().with_encryption(encryption);
601
602 if config.mars {
603 prelogin = prelogin.with_mars(true);
604 }
605
606 if let Some(ref instance) = config.instance {
607 prelogin = prelogin.with_instance(instance);
608 }
609
610 prelogin
611 }
612
613 fn build_login7(config: &Config) -> Login7 {
615 let mut login = Login7::new()
616 .with_packet_size(config.packet_size as u32)
617 .with_app_name(&config.application_name)
618 .with_server_name(&config.host)
619 .with_hostname(&config.host);
620
621 if let Some(ref database) = config.database {
622 login = login.with_database(database);
623 }
624
625 match &config.credentials {
627 mssql_auth::Credentials::SqlServer { username, password } => {
628 login = login.with_sql_auth(username.as_ref(), password.as_ref());
629 }
630 _ => {}
632 }
633
634 login
635 }
636
637 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
639 where
640 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
641 {
642 let payload = prelogin.encode();
643 let max_packet = MAX_PACKET_SIZE;
644
645 connection
646 .send_message(PacketType::PreLogin, payload, max_packet)
647 .await
648 .map_err(|e| Error::Protocol(e.to_string()))
649 }
650
651 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
653 where
654 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
655 {
656 let message = connection
657 .read_message()
658 .await
659 .map_err(|e| Error::Protocol(e.to_string()))?
660 .ok_or(Error::ConnectionClosed)?;
661
662 PreLogin::decode(&message.payload[..]).map_err(|e| Error::Protocol(e.to_string()))
663 }
664
665 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
667 where
668 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
669 {
670 let payload = login.encode();
671 let max_packet = MAX_PACKET_SIZE;
672
673 connection
674 .send_message(PacketType::Tds7Login, payload, max_packet)
675 .await
676 .map_err(|e| Error::Protocol(e.to_string()))
677 }
678
679 async fn process_login_response<T>(
683 connection: &mut Connection<T>,
684 ) -> Result<(Option<u32>, Option<String>, Option<(String, u16)>)>
685 where
686 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
687 {
688 let message = connection
689 .read_message()
690 .await
691 .map_err(|e| Error::Protocol(e.to_string()))?
692 .ok_or(Error::ConnectionClosed)?;
693
694 let response_bytes = message.payload;
695
696 let mut parser = TokenParser::new(response_bytes);
697 let mut server_version = None;
698 let mut database = None;
699 let mut routing = None;
700
701 while let Some(token) = parser
702 .next_token()
703 .map_err(|e| Error::Protocol(e.to_string()))?
704 {
705 match token {
706 Token::LoginAck(ack) => {
707 tracing::info!(
708 version = ack.tds_version,
709 interface = ack.interface,
710 prog_name = %ack.prog_name,
711 "login acknowledged"
712 );
713 server_version = Some(ack.tds_version);
714 }
715 Token::EnvChange(env) => {
716 Self::process_env_change(&env, &mut database, &mut routing);
717 }
718 Token::Error(err) => {
719 return Err(Error::Server {
720 number: err.number,
721 state: err.state,
722 class: err.class,
723 message: err.message.clone(),
724 server: if err.server.is_empty() {
725 None
726 } else {
727 Some(err.server.clone())
728 },
729 procedure: if err.procedure.is_empty() {
730 None
731 } else {
732 Some(err.procedure.clone())
733 },
734 line: err.line as u32,
735 });
736 }
737 Token::Info(info) => {
738 tracing::info!(
739 number = info.number,
740 message = %info.message,
741 "server info message"
742 );
743 }
744 Token::Done(done) => {
745 if done.status.error {
746 return Err(Error::Protocol("login failed".to_string()));
747 }
748 break;
749 }
750 _ => {}
751 }
752 }
753
754 Ok((server_version, database, routing))
755 }
756
757 fn process_env_change(
759 env: &EnvChange,
760 database: &mut Option<String>,
761 routing: &mut Option<(String, u16)>,
762 ) {
763 use tds_protocol::token::EnvChangeValue;
764
765 match env.env_type {
766 EnvChangeType::Database => {
767 if let EnvChangeValue::String(ref new_value) = env.new_value {
768 tracing::debug!(database = %new_value, "database changed");
769 *database = Some(new_value.clone());
770 }
771 }
772 EnvChangeType::Routing => {
773 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
774 tracing::info!(host = %host, port = port, "routing redirect received");
775 *routing = Some((host.clone(), port));
776 }
777 }
778 _ => {
779 if let EnvChangeValue::String(ref new_value) = env.new_value {
780 tracing::debug!(
781 env_type = ?env.env_type,
782 new_value = %new_value,
783 "environment change"
784 );
785 }
786 }
787 }
788 }
789}
790
791impl<S: ConnectionState> Client<S> {
793 async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
799 let payload =
800 tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
801 let max_packet = self.config.packet_size as usize;
802
803 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
804
805 match connection {
806 ConnectionHandle::Tls(conn) => {
807 conn.send_message(PacketType::SqlBatch, payload, max_packet)
808 .await
809 .map_err(|e| Error::Protocol(e.to_string()))?;
810 }
811 ConnectionHandle::TlsPrelogin(conn) => {
812 conn.send_message(PacketType::SqlBatch, payload, max_packet)
813 .await
814 .map_err(|e| Error::Protocol(e.to_string()))?;
815 }
816 ConnectionHandle::Plain(conn) => {
817 conn.send_message(PacketType::SqlBatch, payload, max_packet)
818 .await
819 .map_err(|e| Error::Protocol(e.to_string()))?;
820 }
821 }
822
823 Ok(())
824 }
825
826 async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
830 let payload = rpc.encode_with_transaction(self.transaction_descriptor);
831 let max_packet = self.config.packet_size as usize;
832
833 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
834
835 match connection {
836 ConnectionHandle::Tls(conn) => {
837 conn.send_message(PacketType::Rpc, payload, max_packet)
838 .await
839 .map_err(|e| Error::Protocol(e.to_string()))?;
840 }
841 ConnectionHandle::TlsPrelogin(conn) => {
842 conn.send_message(PacketType::Rpc, payload, max_packet)
843 .await
844 .map_err(|e| Error::Protocol(e.to_string()))?;
845 }
846 ConnectionHandle::Plain(conn) => {
847 conn.send_message(PacketType::Rpc, payload, max_packet)
848 .await
849 .map_err(|e| Error::Protocol(e.to_string()))?;
850 }
851 }
852
853 Ok(())
854 }
855
856 fn convert_params(params: &[&(dyn crate::ToSql + Sync)]) -> Result<Vec<RpcParam>> {
858 use bytes::{BufMut, BytesMut};
859 use mssql_types::SqlValue;
860
861 params
862 .iter()
863 .enumerate()
864 .map(|(i, p)| {
865 let sql_value = p.to_sql()?;
866 let name = format!("@p{}", i + 1);
867
868 Ok(match sql_value {
869 SqlValue::Null => RpcParam::null(&name, RpcTypeInfo::nvarchar(1)),
870 SqlValue::Bool(v) => {
871 let mut buf = BytesMut::with_capacity(1);
872 buf.put_u8(if v { 1 } else { 0 });
873 RpcParam::new(&name, RpcTypeInfo::bit(), buf.freeze())
874 }
875 SqlValue::TinyInt(v) => {
876 let mut buf = BytesMut::with_capacity(1);
877 buf.put_u8(v);
878 RpcParam::new(&name, RpcTypeInfo::tinyint(), buf.freeze())
879 }
880 SqlValue::SmallInt(v) => {
881 let mut buf = BytesMut::with_capacity(2);
882 buf.put_i16_le(v);
883 RpcParam::new(&name, RpcTypeInfo::smallint(), buf.freeze())
884 }
885 SqlValue::Int(v) => RpcParam::int(&name, v),
886 SqlValue::BigInt(v) => RpcParam::bigint(&name, v),
887 SqlValue::Float(v) => {
888 let mut buf = BytesMut::with_capacity(4);
889 buf.put_f32_le(v);
890 RpcParam::new(&name, RpcTypeInfo::real(), buf.freeze())
891 }
892 SqlValue::Double(v) => {
893 let mut buf = BytesMut::with_capacity(8);
894 buf.put_f64_le(v);
895 RpcParam::new(&name, RpcTypeInfo::float(), buf.freeze())
896 }
897 SqlValue::String(ref s) => RpcParam::nvarchar(&name, s),
898 SqlValue::Binary(ref b) => {
899 RpcParam::new(&name, RpcTypeInfo::varbinary(b.len() as u16), b.clone())
900 }
901 SqlValue::Xml(ref s) => RpcParam::nvarchar(&name, s),
902 #[cfg(feature = "uuid")]
903 SqlValue::Uuid(u) => {
904 let bytes = u.as_bytes();
906 let mut buf = BytesMut::with_capacity(16);
907 buf.put_u32_le(u32::from_be_bytes([
909 bytes[0], bytes[1], bytes[2], bytes[3],
910 ]));
911 buf.put_u16_le(u16::from_be_bytes([bytes[4], bytes[5]]));
912 buf.put_u16_le(u16::from_be_bytes([bytes[6], bytes[7]]));
913 buf.put_slice(&bytes[8..16]);
914 RpcParam::new(&name, RpcTypeInfo::uniqueidentifier(), buf.freeze())
915 }
916 #[cfg(feature = "decimal")]
917 SqlValue::Decimal(d) => {
918 RpcParam::nvarchar(&name, &d.to_string())
920 }
921 #[cfg(feature = "chrono")]
922 SqlValue::Date(_)
923 | SqlValue::Time(_)
924 | SqlValue::DateTime(_)
925 | SqlValue::DateTimeOffset(_) => {
926 let s = match &sql_value {
929 SqlValue::Date(d) => d.to_string(),
930 SqlValue::Time(t) => t.to_string(),
931 SqlValue::DateTime(dt) => dt.to_string(),
932 SqlValue::DateTimeOffset(dto) => dto.to_rfc3339(),
933 _ => unreachable!(),
934 };
935 RpcParam::nvarchar(&name, &s)
936 }
937 #[cfg(feature = "json")]
938 SqlValue::Json(ref j) => RpcParam::nvarchar(&name, &j.to_string()),
939 })
940 })
941 .collect()
942 }
943
944 async fn read_query_response(
946 &mut self,
947 ) -> Result<(Vec<crate::row::Column>, Vec<crate::row::Row>)> {
948 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
949
950 let message = match connection {
951 ConnectionHandle::Tls(conn) => conn
952 .read_message()
953 .await
954 .map_err(|e| Error::Protocol(e.to_string()))?,
955 ConnectionHandle::TlsPrelogin(conn) => conn
956 .read_message()
957 .await
958 .map_err(|e| Error::Protocol(e.to_string()))?,
959 ConnectionHandle::Plain(conn) => conn
960 .read_message()
961 .await
962 .map_err(|e| Error::Protocol(e.to_string()))?,
963 }
964 .ok_or(Error::ConnectionClosed)?;
965
966 let mut parser = TokenParser::new(message.payload);
967 let mut columns: Vec<crate::row::Column> = Vec::new();
968 let mut rows: Vec<crate::row::Row> = Vec::new();
969 let mut protocol_metadata: Option<ColMetaData> = None;
970
971 loop {
972 let token = parser
974 .next_token_with_metadata(protocol_metadata.as_ref())
975 .map_err(|e| Error::Protocol(e.to_string()))?;
976
977 let Some(token) = token else {
978 break;
979 };
980
981 match token {
982 Token::ColMetaData(meta) => {
983 columns = meta
984 .columns
985 .iter()
986 .enumerate()
987 .map(|(i, col)| {
988 let type_name = format!("{:?}", col.type_id);
989 let mut column = crate::row::Column::new(&col.name, i, type_name)
990 .with_nullable(col.flags & 0x01 != 0);
991
992 if let Some(max_len) = col.type_info.max_length {
993 column = column.with_max_length(max_len);
994 }
995 if let (Some(prec), Some(scale)) =
996 (col.type_info.precision, col.type_info.scale)
997 {
998 column = column.with_precision_scale(prec, scale);
999 }
1000 column
1001 })
1002 .collect();
1003
1004 tracing::debug!(columns = columns.len(), "received column metadata");
1005 protocol_metadata = Some(meta);
1006 }
1007 Token::Row(raw_row) => {
1008 if let Some(ref meta) = protocol_metadata {
1009 let row = Self::convert_raw_row(&raw_row, meta, &columns)?;
1010 rows.push(row);
1011 }
1012 }
1013 Token::NbcRow(nbc_row) => {
1014 if let Some(ref meta) = protocol_metadata {
1015 let row = Self::convert_nbc_row(&nbc_row, meta, &columns)?;
1016 rows.push(row);
1017 }
1018 }
1019 Token::Error(err) => {
1020 return Err(Error::Server {
1021 number: err.number,
1022 state: err.state,
1023 class: err.class,
1024 message: err.message.clone(),
1025 server: if err.server.is_empty() {
1026 None
1027 } else {
1028 Some(err.server.clone())
1029 },
1030 procedure: if err.procedure.is_empty() {
1031 None
1032 } else {
1033 Some(err.procedure.clone())
1034 },
1035 line: err.line as u32,
1036 });
1037 }
1038 Token::Done(done) => {
1039 if done.status.error {
1040 return Err(Error::Query("query failed".to_string()));
1041 }
1042 tracing::debug!(
1043 row_count = done.row_count,
1044 has_more = done.status.more,
1045 "query complete"
1046 );
1047 break;
1048 }
1049 Token::DoneProc(done) => {
1050 if done.status.error {
1051 return Err(Error::Query("query failed".to_string()));
1052 }
1053 }
1054 Token::DoneInProc(done) => {
1055 if done.status.error {
1056 return Err(Error::Query("query failed".to_string()));
1057 }
1058 }
1059 Token::Info(info) => {
1060 tracing::debug!(
1061 number = info.number,
1062 message = %info.message,
1063 "server info message"
1064 );
1065 }
1066 _ => {}
1067 }
1068 }
1069
1070 tracing::debug!(
1071 columns = columns.len(),
1072 rows = rows.len(),
1073 "query response parsed"
1074 );
1075 Ok((columns, rows))
1076 }
1077
1078 fn convert_raw_row(
1082 raw: &RawRow,
1083 meta: &ColMetaData,
1084 columns: &[crate::row::Column],
1085 ) -> Result<crate::row::Row> {
1086 let mut values = Vec::with_capacity(meta.columns.len());
1087 let mut buf = raw.data.as_ref();
1088
1089 for col in &meta.columns {
1090 let value = Self::parse_column_value(&mut buf, col)?;
1091 values.push(value);
1092 }
1093
1094 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1095 }
1096
1097 fn convert_nbc_row(
1101 nbc: &NbcRow,
1102 meta: &ColMetaData,
1103 columns: &[crate::row::Column],
1104 ) -> Result<crate::row::Row> {
1105 let mut values = Vec::with_capacity(meta.columns.len());
1106 let mut buf = nbc.data.as_ref();
1107
1108 for (i, col) in meta.columns.iter().enumerate() {
1109 if nbc.is_null(i) {
1110 values.push(mssql_types::SqlValue::Null);
1111 } else {
1112 let value = Self::parse_column_value(&mut buf, col)?;
1113 values.push(value);
1114 }
1115 }
1116
1117 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1118 }
1119
1120 fn parse_column_value(buf: &mut &[u8], col: &ColumnData) -> Result<mssql_types::SqlValue> {
1122 use bytes::Buf;
1123 use mssql_types::SqlValue;
1124 use tds_protocol::types::TypeId;
1125
1126 let value = match col.type_id {
1127 TypeId::Null => SqlValue::Null,
1129
1130 TypeId::Int1 => {
1132 if buf.remaining() < 1 {
1133 return Err(Error::Protocol("unexpected EOF reading TINYINT".into()));
1134 }
1135 SqlValue::TinyInt(buf.get_u8())
1136 }
1137 TypeId::Bit => {
1138 if buf.remaining() < 1 {
1139 return Err(Error::Protocol("unexpected EOF reading BIT".into()));
1140 }
1141 SqlValue::Bool(buf.get_u8() != 0)
1142 }
1143
1144 TypeId::Int2 => {
1146 if buf.remaining() < 2 {
1147 return Err(Error::Protocol("unexpected EOF reading SMALLINT".into()));
1148 }
1149 SqlValue::SmallInt(buf.get_i16_le())
1150 }
1151
1152 TypeId::Int4 => {
1154 if buf.remaining() < 4 {
1155 return Err(Error::Protocol("unexpected EOF reading INT".into()));
1156 }
1157 SqlValue::Int(buf.get_i32_le())
1158 }
1159 TypeId::Float4 => {
1160 if buf.remaining() < 4 {
1161 return Err(Error::Protocol("unexpected EOF reading REAL".into()));
1162 }
1163 SqlValue::Float(buf.get_f32_le())
1164 }
1165
1166 TypeId::Int8 => {
1168 if buf.remaining() < 8 {
1169 return Err(Error::Protocol("unexpected EOF reading BIGINT".into()));
1170 }
1171 SqlValue::BigInt(buf.get_i64_le())
1172 }
1173 TypeId::Float8 => {
1174 if buf.remaining() < 8 {
1175 return Err(Error::Protocol("unexpected EOF reading FLOAT".into()));
1176 }
1177 SqlValue::Double(buf.get_f64_le())
1178 }
1179 TypeId::Money => {
1180 if buf.remaining() < 8 {
1181 return Err(Error::Protocol("unexpected EOF reading MONEY".into()));
1182 }
1183 let high = buf.get_i32_le();
1185 let low = buf.get_u32_le();
1186 let cents = ((high as i64) << 32) | (low as i64);
1187 let value = (cents as f64) / 10000.0;
1188 SqlValue::Double(value)
1189 }
1190 TypeId::Money4 => {
1191 if buf.remaining() < 4 {
1192 return Err(Error::Protocol("unexpected EOF reading SMALLMONEY".into()));
1193 }
1194 let cents = buf.get_i32_le();
1195 let value = (cents as f64) / 10000.0;
1196 SqlValue::Double(value)
1197 }
1198
1199 TypeId::IntN => {
1201 if buf.remaining() < 1 {
1202 return Err(Error::Protocol("unexpected EOF reading IntN length".into()));
1203 }
1204 let len = buf.get_u8();
1205 match len {
1206 0 => SqlValue::Null,
1207 1 => SqlValue::TinyInt(buf.get_u8()),
1208 2 => SqlValue::SmallInt(buf.get_i16_le()),
1209 4 => SqlValue::Int(buf.get_i32_le()),
1210 8 => SqlValue::BigInt(buf.get_i64_le()),
1211 _ => {
1212 return Err(Error::Protocol(format!("invalid IntN length: {len}")));
1213 }
1214 }
1215 }
1216 TypeId::FloatN => {
1217 if buf.remaining() < 1 {
1218 return Err(Error::Protocol(
1219 "unexpected EOF reading FloatN length".into(),
1220 ));
1221 }
1222 let len = buf.get_u8();
1223 match len {
1224 0 => SqlValue::Null,
1225 4 => SqlValue::Float(buf.get_f32_le()),
1226 8 => SqlValue::Double(buf.get_f64_le()),
1227 _ => {
1228 return Err(Error::Protocol(format!("invalid FloatN length: {len}")));
1229 }
1230 }
1231 }
1232 TypeId::BitN => {
1233 if buf.remaining() < 1 {
1234 return Err(Error::Protocol("unexpected EOF reading BitN length".into()));
1235 }
1236 let len = buf.get_u8();
1237 match len {
1238 0 => SqlValue::Null,
1239 1 => SqlValue::Bool(buf.get_u8() != 0),
1240 _ => {
1241 return Err(Error::Protocol(format!("invalid BitN length: {len}")));
1242 }
1243 }
1244 }
1245 TypeId::MoneyN => {
1246 if buf.remaining() < 1 {
1247 return Err(Error::Protocol(
1248 "unexpected EOF reading MoneyN length".into(),
1249 ));
1250 }
1251 let len = buf.get_u8();
1252 match len {
1253 0 => SqlValue::Null,
1254 4 => {
1255 let cents = buf.get_i32_le();
1256 SqlValue::Double((cents as f64) / 10000.0)
1257 }
1258 8 => {
1259 let high = buf.get_i32_le();
1260 let low = buf.get_u32_le();
1261 let cents = ((high as i64) << 32) | (low as i64);
1262 SqlValue::Double((cents as f64) / 10000.0)
1263 }
1264 _ => {
1265 return Err(Error::Protocol(format!("invalid MoneyN length: {len}")));
1266 }
1267 }
1268 }
1269 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1271 if buf.remaining() < 1 {
1272 return Err(Error::Protocol(
1273 "unexpected EOF reading DECIMAL/NUMERIC length".into(),
1274 ));
1275 }
1276 let len = buf.get_u8() as usize;
1277 if len == 0 {
1278 SqlValue::Null
1279 } else {
1280 if buf.remaining() < len {
1281 return Err(Error::Protocol(
1282 "unexpected EOF reading DECIMAL/NUMERIC data".into(),
1283 ));
1284 }
1285
1286 let sign = buf.get_u8();
1288 let mantissa_len = len - 1;
1289
1290 let mut mantissa_bytes = [0u8; 16];
1292 for i in 0..mantissa_len.min(16) {
1293 mantissa_bytes[i] = buf.get_u8();
1294 }
1295 for _ in 16..mantissa_len {
1297 buf.get_u8();
1298 }
1299
1300 let mantissa = u128::from_le_bytes(mantissa_bytes);
1301 let scale = col.type_info.scale.unwrap_or(0) as u32;
1302
1303 #[cfg(feature = "decimal")]
1304 {
1305 use rust_decimal::Decimal;
1306 let mut decimal = Decimal::from_i128_with_scale(mantissa as i128, scale);
1307 if sign == 0 {
1308 decimal.set_sign_negative(true);
1309 }
1310 SqlValue::Decimal(decimal)
1311 }
1312
1313 #[cfg(not(feature = "decimal"))]
1314 {
1315 let divisor = 10f64.powi(scale as i32);
1317 let value = (mantissa as f64) / divisor;
1318 let value = if sign == 0 { -value } else { value };
1319 SqlValue::Double(value)
1320 }
1321 }
1322 }
1323
1324 TypeId::DateTimeN => {
1326 if buf.remaining() < 1 {
1327 return Err(Error::Protocol(
1328 "unexpected EOF reading DateTimeN length".into(),
1329 ));
1330 }
1331 let len = buf.get_u8() as usize;
1332 if len == 0 {
1333 SqlValue::Null
1334 } else if buf.remaining() < len {
1335 return Err(Error::Protocol("unexpected EOF reading DateTimeN".into()));
1336 } else {
1337 match len {
1338 4 => {
1339 let days = buf.get_u16_le() as i64;
1341 let minutes = buf.get_u16_le() as u32;
1342 #[cfg(feature = "chrono")]
1343 {
1344 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1345 let date = base + chrono::Duration::days(days);
1346 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1347 minutes * 60,
1348 0,
1349 )
1350 .unwrap();
1351 SqlValue::DateTime(date.and_time(time))
1352 }
1353 #[cfg(not(feature = "chrono"))]
1354 {
1355 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1356 }
1357 }
1358 8 => {
1359 let days = buf.get_i32_le() as i64;
1361 let time_300ths = buf.get_u32_le() as u64;
1362 #[cfg(feature = "chrono")]
1363 {
1364 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1365 let date = base + chrono::Duration::days(days);
1366 let total_ms = (time_300ths * 1000) / 300;
1368 let secs = (total_ms / 1000) as u32;
1369 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1370 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1371 secs, nanos,
1372 )
1373 .unwrap();
1374 SqlValue::DateTime(date.and_time(time))
1375 }
1376 #[cfg(not(feature = "chrono"))]
1377 {
1378 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1379 }
1380 }
1381 _ => {
1382 return Err(Error::Protocol(format!(
1383 "invalid DateTimeN length: {len}"
1384 )));
1385 }
1386 }
1387 }
1388 }
1389
1390 TypeId::DateTime => {
1392 if buf.remaining() < 8 {
1393 return Err(Error::Protocol("unexpected EOF reading DATETIME".into()));
1394 }
1395 let days = buf.get_i32_le() as i64;
1396 let time_300ths = buf.get_u32_le() as u64;
1397 #[cfg(feature = "chrono")]
1398 {
1399 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1400 let date = base + chrono::Duration::days(days);
1401 let total_ms = (time_300ths * 1000) / 300;
1402 let secs = (total_ms / 1000) as u32;
1403 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1404 let time =
1405 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos).unwrap();
1406 SqlValue::DateTime(date.and_time(time))
1407 }
1408 #[cfg(not(feature = "chrono"))]
1409 {
1410 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1411 }
1412 }
1413
1414 TypeId::DateTime4 => {
1416 if buf.remaining() < 4 {
1417 return Err(Error::Protocol(
1418 "unexpected EOF reading SMALLDATETIME".into(),
1419 ));
1420 }
1421 let days = buf.get_u16_le() as i64;
1422 let minutes = buf.get_u16_le() as u32;
1423 #[cfg(feature = "chrono")]
1424 {
1425 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1426 let date = base + chrono::Duration::days(days);
1427 let time =
1428 chrono::NaiveTime::from_num_seconds_from_midnight_opt(minutes * 60, 0)
1429 .unwrap();
1430 SqlValue::DateTime(date.and_time(time))
1431 }
1432 #[cfg(not(feature = "chrono"))]
1433 {
1434 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1435 }
1436 }
1437
1438 TypeId::Date => {
1440 if buf.remaining() < 1 {
1441 return Err(Error::Protocol("unexpected EOF reading DATE length".into()));
1442 }
1443 let len = buf.get_u8() as usize;
1444 if len == 0 {
1445 SqlValue::Null
1446 } else if len != 3 {
1447 return Err(Error::Protocol(format!("invalid DATE length: {len}")));
1448 } else if buf.remaining() < 3 {
1449 return Err(Error::Protocol("unexpected EOF reading DATE".into()));
1450 } else {
1451 let days = buf.get_u8() as u32
1453 | ((buf.get_u8() as u32) << 8)
1454 | ((buf.get_u8() as u32) << 16);
1455 #[cfg(feature = "chrono")]
1456 {
1457 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1458 let date = base + chrono::Duration::days(days as i64);
1459 SqlValue::Date(date)
1460 }
1461 #[cfg(not(feature = "chrono"))]
1462 {
1463 SqlValue::String(format!("DATE({days})"))
1464 }
1465 }
1466 }
1467
1468 TypeId::Time => {
1470 if buf.remaining() < 1 {
1471 return Err(Error::Protocol("unexpected EOF reading TIME length".into()));
1472 }
1473 let len = buf.get_u8() as usize;
1474 if len == 0 {
1475 SqlValue::Null
1476 } else if buf.remaining() < len {
1477 return Err(Error::Protocol("unexpected EOF reading TIME".into()));
1478 } else {
1479 let scale = col.type_info.scale.unwrap_or(7);
1480 let mut time_bytes = [0u8; 8];
1481 for byte in time_bytes.iter_mut().take(len) {
1482 *byte = buf.get_u8();
1483 }
1484 let intervals = u64::from_le_bytes(time_bytes);
1485 #[cfg(feature = "chrono")]
1486 {
1487 let time = Self::intervals_to_time(intervals, scale);
1488 SqlValue::Time(time)
1489 }
1490 #[cfg(not(feature = "chrono"))]
1491 {
1492 SqlValue::String(format!("TIME({intervals})"))
1493 }
1494 }
1495 }
1496
1497 TypeId::DateTime2 => {
1499 if buf.remaining() < 1 {
1500 return Err(Error::Protocol(
1501 "unexpected EOF reading DATETIME2 length".into(),
1502 ));
1503 }
1504 let len = buf.get_u8() as usize;
1505 if len == 0 {
1506 SqlValue::Null
1507 } else if buf.remaining() < len {
1508 return Err(Error::Protocol("unexpected EOF reading DATETIME2".into()));
1509 } else {
1510 let scale = col.type_info.scale.unwrap_or(7);
1511 let time_len = Self::time_bytes_for_scale(scale);
1512
1513 let mut time_bytes = [0u8; 8];
1515 for byte in time_bytes.iter_mut().take(time_len) {
1516 *byte = buf.get_u8();
1517 }
1518 let intervals = u64::from_le_bytes(time_bytes);
1519
1520 let days = buf.get_u8() as u32
1522 | ((buf.get_u8() as u32) << 8)
1523 | ((buf.get_u8() as u32) << 16);
1524
1525 #[cfg(feature = "chrono")]
1526 {
1527 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1528 let date = base + chrono::Duration::days(days as i64);
1529 let time = Self::intervals_to_time(intervals, scale);
1530 SqlValue::DateTime(date.and_time(time))
1531 }
1532 #[cfg(not(feature = "chrono"))]
1533 {
1534 SqlValue::String(format!("DATETIME2({days},{intervals})"))
1535 }
1536 }
1537 }
1538
1539 TypeId::DateTimeOffset => {
1541 if buf.remaining() < 1 {
1542 return Err(Error::Protocol(
1543 "unexpected EOF reading DATETIMEOFFSET length".into(),
1544 ));
1545 }
1546 let len = buf.get_u8() as usize;
1547 if len == 0 {
1548 SqlValue::Null
1549 } else if buf.remaining() < len {
1550 return Err(Error::Protocol(
1551 "unexpected EOF reading DATETIMEOFFSET".into(),
1552 ));
1553 } else {
1554 let scale = col.type_info.scale.unwrap_or(7);
1555 let time_len = Self::time_bytes_for_scale(scale);
1556
1557 let mut time_bytes = [0u8; 8];
1559 for byte in time_bytes.iter_mut().take(time_len) {
1560 *byte = buf.get_u8();
1561 }
1562 let intervals = u64::from_le_bytes(time_bytes);
1563
1564 let days = buf.get_u8() as u32
1566 | ((buf.get_u8() as u32) << 8)
1567 | ((buf.get_u8() as u32) << 16);
1568
1569 let offset_minutes = buf.get_i16_le();
1571
1572 #[cfg(feature = "chrono")]
1573 {
1574 use chrono::TimeZone;
1575 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1576 let date = base + chrono::Duration::days(days as i64);
1577 let time = Self::intervals_to_time(intervals, scale);
1578 let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
1579 .unwrap_or_else(|| chrono::FixedOffset::east_opt(0).unwrap());
1580 let datetime = offset
1581 .from_local_datetime(&date.and_time(time))
1582 .single()
1583 .unwrap_or_else(|| offset.from_utc_datetime(&date.and_time(time)));
1584 SqlValue::DateTimeOffset(datetime)
1585 }
1586 #[cfg(not(feature = "chrono"))]
1587 {
1588 SqlValue::String(format!(
1589 "DATETIMEOFFSET({days},{intervals},{offset_minutes})"
1590 ))
1591 }
1592 }
1593 }
1594
1595 TypeId::BigVarChar | TypeId::BigChar | TypeId::Text => {
1597 if buf.remaining() < 2 {
1599 return Err(Error::Protocol(
1600 "unexpected EOF reading varchar length".into(),
1601 ));
1602 }
1603 let len = buf.get_u16_le();
1604 if len == 0xFFFF {
1605 SqlValue::Null
1606 } else if buf.remaining() < len as usize {
1607 return Err(Error::Protocol(
1608 "unexpected EOF reading varchar data".into(),
1609 ));
1610 } else {
1611 let data = &buf[..len as usize];
1612 let s = String::from_utf8_lossy(data).into_owned();
1613 buf.advance(len as usize);
1614 SqlValue::String(s)
1615 }
1616 }
1617 TypeId::NVarChar | TypeId::NChar | TypeId::NText => {
1618 if buf.remaining() < 2 {
1620 return Err(Error::Protocol(
1621 "unexpected EOF reading nvarchar length".into(),
1622 ));
1623 }
1624 let len = buf.get_u16_le();
1625 if len == 0xFFFF {
1626 SqlValue::Null
1627 } else if buf.remaining() < len as usize {
1628 return Err(Error::Protocol(
1629 "unexpected EOF reading nvarchar data".into(),
1630 ));
1631 } else {
1632 let data = &buf[..len as usize];
1633 let utf16: Vec<u16> = data
1635 .chunks_exact(2)
1636 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
1637 .collect();
1638 let s = String::from_utf16(&utf16)
1639 .map_err(|_| Error::Protocol("invalid UTF-16 in nvarchar".into()))?;
1640 buf.advance(len as usize);
1641 SqlValue::String(s)
1642 }
1643 }
1644
1645 TypeId::BigVarBinary | TypeId::BigBinary | TypeId::Image => {
1647 if buf.remaining() < 2 {
1648 return Err(Error::Protocol(
1649 "unexpected EOF reading varbinary length".into(),
1650 ));
1651 }
1652 let len = buf.get_u16_le();
1653 if len == 0xFFFF {
1654 SqlValue::Null
1655 } else if buf.remaining() < len as usize {
1656 return Err(Error::Protocol(
1657 "unexpected EOF reading varbinary data".into(),
1658 ));
1659 } else {
1660 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
1661 buf.advance(len as usize);
1662 SqlValue::Binary(data)
1663 }
1664 }
1665
1666 TypeId::Guid => {
1668 if buf.remaining() < 1 {
1669 return Err(Error::Protocol("unexpected EOF reading GUID length".into()));
1670 }
1671 let len = buf.get_u8();
1672 if len == 0 {
1673 SqlValue::Null
1674 } else if len != 16 {
1675 return Err(Error::Protocol(format!("invalid GUID length: {len}")));
1676 } else if buf.remaining() < 16 {
1677 return Err(Error::Protocol("unexpected EOF reading GUID".into()));
1678 } else {
1679 let data = bytes::Bytes::copy_from_slice(&buf[..16]);
1681 buf.advance(16);
1682 SqlValue::Binary(data)
1683 }
1684 }
1685
1686 _ => {
1688 if buf.remaining() < 2 {
1690 return Err(Error::Protocol(format!(
1691 "unexpected EOF reading {:?}",
1692 col.type_id
1693 )));
1694 }
1695 let len = buf.get_u16_le();
1696 if len == 0xFFFF {
1697 SqlValue::Null
1698 } else if buf.remaining() < len as usize {
1699 return Err(Error::Protocol(format!(
1700 "unexpected EOF reading {:?} data",
1701 col.type_id
1702 )));
1703 } else {
1704 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
1705 buf.advance(len as usize);
1706 SqlValue::Binary(data)
1707 }
1708 }
1709 };
1710
1711 Ok(value)
1712 }
1713
1714 fn time_bytes_for_scale(scale: u8) -> usize {
1716 match scale {
1717 0..=2 => 3,
1718 3..=4 => 4,
1719 5..=7 => 5,
1720 _ => 5, }
1722 }
1723
1724 #[cfg(feature = "chrono")]
1726 fn intervals_to_time(intervals: u64, scale: u8) -> chrono::NaiveTime {
1727 let nanos = match scale {
1737 0 => intervals * 1_000_000_000,
1738 1 => intervals * 100_000_000,
1739 2 => intervals * 10_000_000,
1740 3 => intervals * 1_000_000,
1741 4 => intervals * 100_000,
1742 5 => intervals * 10_000,
1743 6 => intervals * 1_000,
1744 7 => intervals * 100,
1745 _ => intervals * 100,
1746 };
1747
1748 let secs = (nanos / 1_000_000_000) as u32;
1749 let nano_part = (nanos % 1_000_000_000) as u32;
1750
1751 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part)
1752 .unwrap_or_else(|| chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap())
1753 }
1754
1755 async fn read_execute_result(&mut self) -> Result<u64> {
1757 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
1758
1759 let message = match connection {
1760 ConnectionHandle::Tls(conn) => conn
1761 .read_message()
1762 .await
1763 .map_err(|e| Error::Protocol(e.to_string()))?,
1764 ConnectionHandle::TlsPrelogin(conn) => conn
1765 .read_message()
1766 .await
1767 .map_err(|e| Error::Protocol(e.to_string()))?,
1768 ConnectionHandle::Plain(conn) => conn
1769 .read_message()
1770 .await
1771 .map_err(|e| Error::Protocol(e.to_string()))?,
1772 }
1773 .ok_or(Error::ConnectionClosed)?;
1774
1775 let mut parser = TokenParser::new(message.payload);
1776 let mut rows_affected = 0u64;
1777 let mut current_metadata: Option<ColMetaData> = None;
1778
1779 loop {
1780 let token = parser
1782 .next_token_with_metadata(current_metadata.as_ref())
1783 .map_err(|e| Error::Protocol(e.to_string()))?;
1784
1785 let Some(token) = token else {
1786 break;
1787 };
1788
1789 match token {
1790 Token::ColMetaData(meta) => {
1791 current_metadata = Some(meta);
1793 }
1794 Token::Row(_) | Token::NbcRow(_) => {
1795 }
1798 Token::Done(done) => {
1799 if done.status.error {
1800 return Err(Error::Query("execution failed".to_string()));
1801 }
1802 if done.status.count {
1803 rows_affected = done.row_count;
1804 }
1805 break;
1806 }
1807 Token::DoneProc(done) => {
1808 if done.status.count {
1809 rows_affected = done.row_count;
1810 }
1811 }
1812 Token::DoneInProc(done) => {
1813 if done.status.count {
1814 rows_affected = done.row_count;
1815 }
1816 }
1817 Token::Error(err) => {
1818 return Err(Error::Server {
1819 number: err.number,
1820 state: err.state,
1821 class: err.class,
1822 message: err.message.clone(),
1823 server: if err.server.is_empty() {
1824 None
1825 } else {
1826 Some(err.server.clone())
1827 },
1828 procedure: if err.procedure.is_empty() {
1829 None
1830 } else {
1831 Some(err.procedure.clone())
1832 },
1833 line: err.line as u32,
1834 });
1835 }
1836 Token::Info(info) => {
1837 tracing::info!(
1838 number = info.number,
1839 message = %info.message,
1840 "server info message"
1841 );
1842 }
1843 _ => {}
1844 }
1845 }
1846
1847 Ok(rows_affected)
1848 }
1849
1850 async fn read_transaction_begin_result(&mut self) -> Result<u64> {
1856 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
1857
1858 let message = match connection {
1859 ConnectionHandle::Tls(conn) => conn
1860 .read_message()
1861 .await
1862 .map_err(|e| Error::Protocol(e.to_string()))?,
1863 ConnectionHandle::TlsPrelogin(conn) => conn
1864 .read_message()
1865 .await
1866 .map_err(|e| Error::Protocol(e.to_string()))?,
1867 ConnectionHandle::Plain(conn) => conn
1868 .read_message()
1869 .await
1870 .map_err(|e| Error::Protocol(e.to_string()))?,
1871 }
1872 .ok_or(Error::ConnectionClosed)?;
1873
1874 let mut parser = TokenParser::new(message.payload);
1875 let mut transaction_descriptor: u64 = 0;
1876
1877 loop {
1878 let token = parser
1879 .next_token()
1880 .map_err(|e| Error::Protocol(e.to_string()))?;
1881
1882 let Some(token) = token else {
1883 break;
1884 };
1885
1886 match token {
1887 Token::EnvChange(env) => {
1888 if env.env_type == EnvChangeType::BeginTransaction {
1889 if let tds_protocol::token::EnvChangeValue::Binary(ref data) = env.new_value
1892 {
1893 if data.len() >= 8 {
1894 transaction_descriptor = u64::from_le_bytes([
1895 data[0], data[1], data[2], data[3], data[4], data[5], data[6],
1896 data[7],
1897 ]);
1898 tracing::debug!(
1899 transaction_descriptor =
1900 format!("0x{:016X}", transaction_descriptor),
1901 "transaction begun"
1902 );
1903 }
1904 }
1905 }
1906 }
1907 Token::Done(done) => {
1908 if done.status.error {
1909 return Err(Error::Query("BEGIN TRANSACTION failed".to_string()));
1910 }
1911 break;
1912 }
1913 Token::Error(err) => {
1914 return Err(Error::Server {
1915 number: err.number,
1916 state: err.state,
1917 class: err.class,
1918 message: err.message.clone(),
1919 server: if err.server.is_empty() {
1920 None
1921 } else {
1922 Some(err.server.clone())
1923 },
1924 procedure: if err.procedure.is_empty() {
1925 None
1926 } else {
1927 Some(err.procedure.clone())
1928 },
1929 line: err.line as u32,
1930 });
1931 }
1932 Token::Info(info) => {
1933 tracing::info!(
1934 number = info.number,
1935 message = %info.message,
1936 "server info message"
1937 );
1938 }
1939 _ => {}
1940 }
1941 }
1942
1943 Ok(transaction_descriptor)
1944 }
1945}
1946
1947impl Client<Ready> {
1948 pub async fn query<'a>(
1973 &'a mut self,
1974 sql: &str,
1975 params: &[&(dyn crate::ToSql + Sync)],
1976 ) -> Result<QueryStream<'a>> {
1977 tracing::debug!(sql = sql, params_count = params.len(), "executing query");
1978
1979 #[cfg(feature = "otel")]
1980 let instrumentation = self.instrumentation.clone();
1981 #[cfg(feature = "otel")]
1982 let mut span = instrumentation.query_span(sql);
1983
1984 let result = async {
1985 if params.is_empty() {
1986 self.send_sql_batch(sql).await?;
1988 } else {
1989 let rpc_params = Self::convert_params(params)?;
1991 let rpc = RpcRequest::execute_sql(sql, rpc_params);
1992 self.send_rpc(&rpc).await?;
1993 }
1994
1995 self.read_query_response().await
1997 }
1998 .await;
1999
2000 #[cfg(feature = "otel")]
2001 match &result {
2002 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2003 Err(e) => InstrumentationContext::record_error(&mut span, e),
2004 }
2005
2006 #[cfg(feature = "otel")]
2008 drop(span);
2009
2010 let (columns, rows) = result?;
2011 Ok(QueryStream::new(columns, rows))
2012 }
2013
2014 pub async fn query_multiple<'a>(
2041 &'a mut self,
2042 sql: &str,
2043 params: &[&(dyn crate::ToSql + Sync)],
2044 ) -> Result<MultiResultStream<'a>> {
2045 tracing::debug!(
2046 sql = sql,
2047 params_count = params.len(),
2048 "executing multi-result query"
2049 );
2050
2051 if params.is_empty() {
2052 self.send_sql_batch(sql).await?;
2054 } else {
2055 let rpc_params = Self::convert_params(params)?;
2057 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2058 self.send_rpc(&rpc).await?;
2059 }
2060
2061 let result_sets = self.read_multi_result_response().await?;
2063 Ok(MultiResultStream::new(result_sets))
2064 }
2065
2066 async fn read_multi_result_response(&mut self) -> Result<Vec<crate::stream::ResultSet>> {
2068 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2069
2070 let message = match connection {
2071 ConnectionHandle::Tls(conn) => conn
2072 .read_message()
2073 .await
2074 .map_err(|e| Error::Protocol(e.to_string()))?,
2075 ConnectionHandle::TlsPrelogin(conn) => conn
2076 .read_message()
2077 .await
2078 .map_err(|e| Error::Protocol(e.to_string()))?,
2079 ConnectionHandle::Plain(conn) => conn
2080 .read_message()
2081 .await
2082 .map_err(|e| Error::Protocol(e.to_string()))?,
2083 }
2084 .ok_or(Error::ConnectionClosed)?;
2085
2086 let mut parser = TokenParser::new(message.payload);
2087 let mut result_sets: Vec<crate::stream::ResultSet> = Vec::new();
2088 let mut current_columns: Vec<crate::row::Column> = Vec::new();
2089 let mut current_rows: Vec<crate::row::Row> = Vec::new();
2090 let mut protocol_metadata: Option<ColMetaData> = None;
2091
2092 loop {
2093 let token = parser
2094 .next_token_with_metadata(protocol_metadata.as_ref())
2095 .map_err(|e| Error::Protocol(e.to_string()))?;
2096
2097 let Some(token) = token else {
2098 break;
2099 };
2100
2101 match token {
2102 Token::ColMetaData(meta) => {
2103 if !current_columns.is_empty() {
2105 result_sets.push(crate::stream::ResultSet::new(
2106 std::mem::take(&mut current_columns),
2107 std::mem::take(&mut current_rows),
2108 ));
2109 }
2110
2111 current_columns = meta
2113 .columns
2114 .iter()
2115 .enumerate()
2116 .map(|(i, col)| {
2117 let type_name = format!("{:?}", col.type_id);
2118 let mut column = crate::row::Column::new(&col.name, i, type_name)
2119 .with_nullable(col.flags & 0x01 != 0);
2120
2121 if let Some(max_len) = col.type_info.max_length {
2122 column = column.with_max_length(max_len);
2123 }
2124 if let (Some(prec), Some(scale)) =
2125 (col.type_info.precision, col.type_info.scale)
2126 {
2127 column = column.with_precision_scale(prec, scale);
2128 }
2129 column
2130 })
2131 .collect();
2132
2133 tracing::debug!(
2134 columns = current_columns.len(),
2135 result_set = result_sets.len(),
2136 "received column metadata for result set"
2137 );
2138 protocol_metadata = Some(meta);
2139 }
2140 Token::Row(raw_row) => {
2141 if let Some(ref meta) = protocol_metadata {
2142 let row = Self::convert_raw_row(&raw_row, meta, ¤t_columns)?;
2143 current_rows.push(row);
2144 }
2145 }
2146 Token::NbcRow(nbc_row) => {
2147 if let Some(ref meta) = protocol_metadata {
2148 let row = Self::convert_nbc_row(&nbc_row, meta, ¤t_columns)?;
2149 current_rows.push(row);
2150 }
2151 }
2152 Token::Error(err) => {
2153 return Err(Error::Server {
2154 number: err.number,
2155 state: err.state,
2156 class: err.class,
2157 message: err.message.clone(),
2158 server: if err.server.is_empty() {
2159 None
2160 } else {
2161 Some(err.server.clone())
2162 },
2163 procedure: if err.procedure.is_empty() {
2164 None
2165 } else {
2166 Some(err.procedure.clone())
2167 },
2168 line: err.line as u32,
2169 });
2170 }
2171 Token::Done(done) => {
2172 if done.status.error {
2173 return Err(Error::Query("query failed".to_string()));
2174 }
2175
2176 if !current_columns.is_empty() {
2178 result_sets.push(crate::stream::ResultSet::new(
2179 std::mem::take(&mut current_columns),
2180 std::mem::take(&mut current_rows),
2181 ));
2182 protocol_metadata = None;
2183 }
2184
2185 if !done.status.more {
2187 tracing::debug!(result_sets = result_sets.len(), "all result sets parsed");
2188 break;
2189 }
2190 }
2191 Token::DoneInProc(done) => {
2192 if done.status.error {
2193 return Err(Error::Query("query failed".to_string()));
2194 }
2195
2196 if !current_columns.is_empty() {
2198 result_sets.push(crate::stream::ResultSet::new(
2199 std::mem::take(&mut current_columns),
2200 std::mem::take(&mut current_rows),
2201 ));
2202 protocol_metadata = None;
2203 }
2204
2205 if !done.status.more {
2207 }
2209 }
2210 Token::DoneProc(done) => {
2211 if done.status.error {
2212 return Err(Error::Query("query failed".to_string()));
2213 }
2214 }
2216 Token::Info(info) => {
2217 tracing::debug!(
2218 number = info.number,
2219 message = %info.message,
2220 "server info message"
2221 );
2222 }
2223 _ => {}
2224 }
2225 }
2226
2227 if !current_columns.is_empty() {
2229 result_sets.push(crate::stream::ResultSet::new(current_columns, current_rows));
2230 }
2231
2232 Ok(result_sets)
2233 }
2234
2235 pub async fn execute(
2239 &mut self,
2240 sql: &str,
2241 params: &[&(dyn crate::ToSql + Sync)],
2242 ) -> Result<u64> {
2243 tracing::debug!(
2244 sql = sql,
2245 params_count = params.len(),
2246 "executing statement"
2247 );
2248
2249 #[cfg(feature = "otel")]
2250 let instrumentation = self.instrumentation.clone();
2251 #[cfg(feature = "otel")]
2252 let mut span = instrumentation.query_span(sql);
2253
2254 let result = async {
2255 if params.is_empty() {
2256 self.send_sql_batch(sql).await?;
2258 } else {
2259 let rpc_params = Self::convert_params(params)?;
2261 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2262 self.send_rpc(&rpc).await?;
2263 }
2264
2265 self.read_execute_result().await
2267 }
2268 .await;
2269
2270 #[cfg(feature = "otel")]
2271 match &result {
2272 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
2273 Err(e) => InstrumentationContext::record_error(&mut span, e),
2274 }
2275
2276 #[cfg(feature = "otel")]
2278 drop(span);
2279
2280 result
2281 }
2282
2283 pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
2290 tracing::debug!("beginning transaction");
2291
2292 #[cfg(feature = "otel")]
2293 let instrumentation = self.instrumentation.clone();
2294 #[cfg(feature = "otel")]
2295 let mut span = instrumentation.transaction_span("BEGIN");
2296
2297 let result = async {
2299 self.send_sql_batch("BEGIN TRANSACTION").await?;
2300 self.read_transaction_begin_result().await
2301 }
2302 .await;
2303
2304 #[cfg(feature = "otel")]
2305 match &result {
2306 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2307 Err(e) => InstrumentationContext::record_error(&mut span, e),
2308 }
2309
2310 #[cfg(feature = "otel")]
2312 drop(span);
2313
2314 let transaction_descriptor = result?;
2315
2316 Ok(Client {
2317 config: self.config,
2318 _state: PhantomData,
2319 connection: self.connection,
2320 server_version: self.server_version,
2321 current_database: self.current_database,
2322 statement_cache: self.statement_cache,
2323 transaction_descriptor, #[cfg(feature = "otel")]
2325 instrumentation: self.instrumentation,
2326 })
2327 }
2328
2329 pub async fn begin_transaction_with_isolation(
2344 mut self,
2345 isolation_level: crate::transaction::IsolationLevel,
2346 ) -> Result<Client<InTransaction>> {
2347 tracing::debug!(
2348 isolation_level = %isolation_level.name(),
2349 "beginning transaction with isolation level"
2350 );
2351
2352 #[cfg(feature = "otel")]
2353 let instrumentation = self.instrumentation.clone();
2354 #[cfg(feature = "otel")]
2355 let mut span = instrumentation.transaction_span("BEGIN");
2356
2357 let result = async {
2359 self.send_sql_batch(isolation_level.as_sql()).await?;
2360 self.read_execute_result().await?;
2361
2362 self.send_sql_batch("BEGIN TRANSACTION").await?;
2364 self.read_transaction_begin_result().await
2365 }
2366 .await;
2367
2368 #[cfg(feature = "otel")]
2369 match &result {
2370 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2371 Err(e) => InstrumentationContext::record_error(&mut span, e),
2372 }
2373
2374 #[cfg(feature = "otel")]
2375 drop(span);
2376
2377 let transaction_descriptor = result?;
2378
2379 Ok(Client {
2380 config: self.config,
2381 _state: PhantomData,
2382 connection: self.connection,
2383 server_version: self.server_version,
2384 current_database: self.current_database,
2385 statement_cache: self.statement_cache,
2386 transaction_descriptor,
2387 #[cfg(feature = "otel")]
2388 instrumentation: self.instrumentation,
2389 })
2390 }
2391
2392 pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
2397 tracing::debug!(sql = sql, "executing simple query");
2398
2399 self.send_sql_batch(sql).await?;
2401
2402 let _ = self.read_execute_result().await?;
2404
2405 Ok(())
2406 }
2407
2408 pub async fn close(self) -> Result<()> {
2410 tracing::debug!("closing connection");
2411 Ok(())
2412 }
2413
2414 #[must_use]
2416 pub fn database(&self) -> Option<&str> {
2417 self.config.database.as_deref()
2418 }
2419
2420 #[must_use]
2422 pub fn host(&self) -> &str {
2423 &self.config.host
2424 }
2425
2426 #[must_use]
2428 pub fn port(&self) -> u16 {
2429 self.config.port
2430 }
2431}
2432
2433impl Client<InTransaction> {
2434 pub async fn query<'a>(
2438 &'a mut self,
2439 sql: &str,
2440 params: &[&(dyn crate::ToSql + Sync)],
2441 ) -> Result<QueryStream<'a>> {
2442 tracing::debug!(
2443 sql = sql,
2444 params_count = params.len(),
2445 "executing query in transaction"
2446 );
2447
2448 #[cfg(feature = "otel")]
2449 let instrumentation = self.instrumentation.clone();
2450 #[cfg(feature = "otel")]
2451 let mut span = instrumentation.query_span(sql);
2452
2453 let result = async {
2454 if params.is_empty() {
2455 self.send_sql_batch(sql).await?;
2457 } else {
2458 let rpc_params = Self::convert_params(params)?;
2460 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2461 self.send_rpc(&rpc).await?;
2462 }
2463
2464 self.read_query_response().await
2466 }
2467 .await;
2468
2469 #[cfg(feature = "otel")]
2470 match &result {
2471 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2472 Err(e) => InstrumentationContext::record_error(&mut span, e),
2473 }
2474
2475 #[cfg(feature = "otel")]
2477 drop(span);
2478
2479 let (columns, rows) = result?;
2480 Ok(QueryStream::new(columns, rows))
2481 }
2482
2483 pub async fn execute(
2487 &mut self,
2488 sql: &str,
2489 params: &[&(dyn crate::ToSql + Sync)],
2490 ) -> Result<u64> {
2491 tracing::debug!(
2492 sql = sql,
2493 params_count = params.len(),
2494 "executing statement in transaction"
2495 );
2496
2497 #[cfg(feature = "otel")]
2498 let instrumentation = self.instrumentation.clone();
2499 #[cfg(feature = "otel")]
2500 let mut span = instrumentation.query_span(sql);
2501
2502 let result = async {
2503 if params.is_empty() {
2504 self.send_sql_batch(sql).await?;
2506 } else {
2507 let rpc_params = Self::convert_params(params)?;
2509 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2510 self.send_rpc(&rpc).await?;
2511 }
2512
2513 self.read_execute_result().await
2515 }
2516 .await;
2517
2518 #[cfg(feature = "otel")]
2519 match &result {
2520 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
2521 Err(e) => InstrumentationContext::record_error(&mut span, e),
2522 }
2523
2524 #[cfg(feature = "otel")]
2526 drop(span);
2527
2528 result
2529 }
2530
2531 pub async fn commit(mut self) -> Result<Client<Ready>> {
2535 tracing::debug!("committing transaction");
2536
2537 #[cfg(feature = "otel")]
2538 let instrumentation = self.instrumentation.clone();
2539 #[cfg(feature = "otel")]
2540 let mut span = instrumentation.transaction_span("COMMIT");
2541
2542 let result = async {
2544 self.send_sql_batch("COMMIT TRANSACTION").await?;
2545 self.read_execute_result().await
2546 }
2547 .await;
2548
2549 #[cfg(feature = "otel")]
2550 match &result {
2551 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2552 Err(e) => InstrumentationContext::record_error(&mut span, e),
2553 }
2554
2555 #[cfg(feature = "otel")]
2557 drop(span);
2558
2559 result?;
2560
2561 Ok(Client {
2562 config: self.config,
2563 _state: PhantomData,
2564 connection: self.connection,
2565 server_version: self.server_version,
2566 current_database: self.current_database,
2567 statement_cache: self.statement_cache,
2568 transaction_descriptor: 0, #[cfg(feature = "otel")]
2570 instrumentation: self.instrumentation,
2571 })
2572 }
2573
2574 pub async fn rollback(mut self) -> Result<Client<Ready>> {
2578 tracing::debug!("rolling back transaction");
2579
2580 #[cfg(feature = "otel")]
2581 let instrumentation = self.instrumentation.clone();
2582 #[cfg(feature = "otel")]
2583 let mut span = instrumentation.transaction_span("ROLLBACK");
2584
2585 let result = async {
2587 self.send_sql_batch("ROLLBACK TRANSACTION").await?;
2588 self.read_execute_result().await
2589 }
2590 .await;
2591
2592 #[cfg(feature = "otel")]
2593 match &result {
2594 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2595 Err(e) => InstrumentationContext::record_error(&mut span, e),
2596 }
2597
2598 #[cfg(feature = "otel")]
2600 drop(span);
2601
2602 result?;
2603
2604 Ok(Client {
2605 config: self.config,
2606 _state: PhantomData,
2607 connection: self.connection,
2608 server_version: self.server_version,
2609 current_database: self.current_database,
2610 statement_cache: self.statement_cache,
2611 transaction_descriptor: 0, #[cfg(feature = "otel")]
2613 instrumentation: self.instrumentation,
2614 })
2615 }
2616
2617 pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
2634 validate_identifier(name)?;
2635 tracing::debug!(name = name, "creating savepoint");
2636
2637 let sql = format!("SAVE TRANSACTION {}", name);
2640 self.send_sql_batch(&sql).await?;
2641 self.read_execute_result().await?;
2642
2643 Ok(SavePoint::new(name.to_string()))
2644 }
2645
2646 pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
2661 tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
2662
2663 let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
2666 self.send_sql_batch(&sql).await?;
2667 self.read_execute_result().await?;
2668
2669 Ok(())
2670 }
2671
2672 pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
2678 tracing::debug!(name = savepoint.name(), "releasing savepoint");
2679
2680 drop(savepoint);
2684 Ok(())
2685 }
2686}
2687
2688fn validate_identifier(name: &str) -> Result<()> {
2690 use once_cell::sync::Lazy;
2691 use regex::Regex;
2692
2693 static IDENTIFIER_RE: Lazy<Regex> =
2694 Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").unwrap());
2695
2696 if name.is_empty() {
2697 return Err(Error::InvalidIdentifier(
2698 "identifier cannot be empty".into(),
2699 ));
2700 }
2701
2702 if !IDENTIFIER_RE.is_match(name) {
2703 return Err(Error::InvalidIdentifier(format!(
2704 "invalid identifier '{}': must start with letter/underscore, \
2705 contain only alphanumerics/_/@/#/$, and be 1-128 characters",
2706 name
2707 )));
2708 }
2709
2710 Ok(())
2711}
2712
2713impl<S: ConnectionState> std::fmt::Debug for Client<S> {
2714 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2715 f.debug_struct("Client")
2716 .field("host", &self.config.host)
2717 .field("port", &self.config.port)
2718 .field("database", &self.config.database)
2719 .finish()
2720 }
2721}
2722
2723#[cfg(test)]
2724#[allow(clippy::unwrap_used)]
2725mod tests {
2726 use super::*;
2727
2728 #[test]
2729 fn test_validate_identifier_valid() {
2730 assert!(validate_identifier("my_table").is_ok());
2731 assert!(validate_identifier("Table123").is_ok());
2732 assert!(validate_identifier("_private").is_ok());
2733 assert!(validate_identifier("sp_test").is_ok());
2734 }
2735
2736 #[test]
2737 fn test_validate_identifier_invalid() {
2738 assert!(validate_identifier("").is_err());
2739 assert!(validate_identifier("123abc").is_err());
2740 assert!(validate_identifier("table-name").is_err());
2741 assert!(validate_identifier("table name").is_err());
2742 assert!(validate_identifier("table;DROP TABLE users").is_err());
2743 }
2744}