1#![allow(clippy::unwrap_used, clippy::expect_used, clippy::needless_range_loop)]
6
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10use bytes::BytesMut;
11use mssql_codec::connection::Connection;
12use mssql_tls::{TlsConfig, TlsConnector, TlsNegotiationMode, TlsStream};
13use tds_protocol::login7::Login7;
14use tds_protocol::packet::{MAX_PACKET_SIZE, PacketType};
15use tds_protocol::prelogin::{EncryptionLevel, PreLogin};
16use tds_protocol::rpc::{RpcParam, RpcRequest, TypeInfo as RpcTypeInfo};
17use tds_protocol::token::{
18 ColMetaData, ColumnData, EnvChange, EnvChangeType, NbcRow, RawRow, Token, TokenParser,
19};
20use tds_protocol::tvp::{
21 TvpColumnDef as TvpWireColumnDef, TvpColumnFlags, TvpEncoder, TvpWireType, encode_tvp_bit,
22 encode_tvp_decimal, encode_tvp_float, encode_tvp_int, encode_tvp_null, encode_tvp_nvarchar,
23 encode_tvp_varbinary,
24};
25use tokio::net::TcpStream;
26use tokio::time::timeout;
27
28use crate::config::Config;
29use crate::error::{Error, Result};
30#[cfg(feature = "otel")]
31use crate::instrumentation::InstrumentationContext;
32use crate::state::{ConnectionState, Disconnected, InTransaction, Ready};
33use crate::statement_cache::StatementCache;
34use crate::stream::{MultiResultStream, QueryStream};
35use crate::transaction::SavePoint;
36
37pub struct Client<S: ConnectionState> {
43 config: Config,
44 _state: PhantomData<S>,
45 connection: Option<ConnectionHandle>,
47 server_version: Option<u32>,
49 current_database: Option<String>,
51 statement_cache: StatementCache,
53 transaction_descriptor: u64,
57 #[cfg(feature = "otel")]
59 instrumentation: InstrumentationContext,
60}
61
62#[allow(dead_code)] enum ConnectionHandle {
70 Tls(Connection<TlsStream<TcpStream>>),
72 TlsPrelogin(Connection<TlsStream<mssql_tls::TlsPreloginWrapper<TcpStream>>>),
74 Plain(Connection<TcpStream>),
76}
77
78impl Client<Disconnected> {
79 pub async fn connect(config: Config) -> Result<Client<Ready>> {
90 let max_redirects = config.redirect.max_redirects;
91 let follow_redirects = config.redirect.follow_redirects;
92 let mut attempts = 0;
93 let mut current_config = config;
94
95 loop {
96 attempts += 1;
97 if attempts > max_redirects + 1 {
98 return Err(Error::TooManyRedirects { max: max_redirects });
99 }
100
101 match Self::try_connect(¤t_config).await {
102 Ok(client) => return Ok(client),
103 Err(Error::Routing { host, port }) => {
104 if !follow_redirects {
105 return Err(Error::Routing { host, port });
106 }
107 tracing::info!(
108 host = %host,
109 port = port,
110 attempt = attempts,
111 max_redirects = max_redirects,
112 "following Azure SQL routing redirect"
113 );
114 current_config = current_config.with_host(&host).with_port(port);
115 continue;
116 }
117 Err(e) => return Err(e),
118 }
119 }
120 }
121
122 async fn try_connect(config: &Config) -> Result<Client<Ready>> {
123 tracing::info!(
124 host = %config.host,
125 port = config.port,
126 database = ?config.database,
127 "connecting to SQL Server"
128 );
129
130 let addr = format!("{}:{}", config.host, config.port);
131
132 tracing::debug!("establishing TCP connection to {}", addr);
134 let tcp_stream = timeout(config.timeouts.connect_timeout, TcpStream::connect(&addr))
135 .await
136 .map_err(|_| Error::ConnectTimeout)?
137 .map_err(|e| Error::Io(Arc::new(e)))?;
138
139 tcp_stream
141 .set_nodelay(true)
142 .map_err(|e| Error::Io(Arc::new(e)))?;
143
144 let tls_mode = TlsNegotiationMode::from_encrypt_mode(config.strict_mode);
146
147 if tls_mode.is_tls_first() {
149 return Self::connect_tds_8(config, tcp_stream).await;
150 }
151
152 Self::connect_tds_7x(config, tcp_stream).await
154 }
155
156 async fn connect_tds_8(config: &Config, tcp_stream: TcpStream) -> Result<Client<Ready>> {
160 tracing::debug!("using TDS 8.0 strict mode (TLS first)");
161
162 let tls_config = TlsConfig::new()
164 .strict_mode(true)
165 .trust_server_certificate(config.trust_server_certificate);
166
167 let tls_connector = TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
168
169 let tls_stream = timeout(
171 config.timeouts.tls_timeout,
172 tls_connector.connect(tcp_stream, &config.host),
173 )
174 .await
175 .map_err(|_| Error::TlsTimeout)?
176 .map_err(|e| Error::Tls(e.to_string()))?;
177
178 tracing::debug!("TLS handshake completed (strict mode)");
179
180 let mut connection = Connection::new(tls_stream);
182
183 let prelogin = Self::build_prelogin(config, EncryptionLevel::Required);
185 Self::send_prelogin(&mut connection, &prelogin).await?;
186 let _prelogin_response = Self::receive_prelogin(&mut connection).await?;
187
188 let login = Self::build_login7(config);
190 Self::send_login7(&mut connection, &login).await?;
191
192 let (server_version, current_database, routing) =
194 Self::process_login_response(&mut connection).await?;
195
196 if let Some((host, port)) = routing {
198 return Err(Error::Routing { host, port });
199 }
200
201 Ok(Client {
202 config: config.clone(),
203 _state: PhantomData,
204 connection: Some(ConnectionHandle::Tls(connection)),
205 server_version,
206 current_database: current_database.clone(),
207 statement_cache: StatementCache::with_default_size(),
208 transaction_descriptor: 0, #[cfg(feature = "otel")]
210 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
211 .with_database(current_database.unwrap_or_default()),
212 })
213 }
214
215 async fn connect_tds_7x(config: &Config, mut tcp_stream: TcpStream) -> Result<Client<Ready>> {
223 use bytes::BufMut;
224 use tds_protocol::packet::{PACKET_HEADER_SIZE, PacketHeader, PacketStatus};
225 use tokio::io::{AsyncReadExt, AsyncWriteExt};
226
227 tracing::debug!("using TDS 7.x flow (PreLogin first)");
228
229 let client_encryption = if config.encrypt {
232 EncryptionLevel::On
233 } else {
234 EncryptionLevel::Off
235 };
236 let prelogin = Self::build_prelogin(config, client_encryption);
237 tracing::debug!(encryption = ?client_encryption, "sending PreLogin");
238 let prelogin_bytes = prelogin.encode();
239
240 let header = PacketHeader::new(
242 PacketType::PreLogin,
243 PacketStatus::END_OF_MESSAGE,
244 (PACKET_HEADER_SIZE + prelogin_bytes.len()) as u16,
245 );
246
247 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + prelogin_bytes.len());
248 header.encode(&mut packet_buf);
249 packet_buf.put_slice(&prelogin_bytes);
250
251 tcp_stream
252 .write_all(&packet_buf)
253 .await
254 .map_err(|e| Error::Io(Arc::new(e)))?;
255
256 let mut header_buf = [0u8; PACKET_HEADER_SIZE];
258 tcp_stream
259 .read_exact(&mut header_buf)
260 .await
261 .map_err(|e| Error::Io(Arc::new(e)))?;
262
263 let response_length = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
264 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
265
266 let mut response_buf = vec![0u8; payload_length];
267 tcp_stream
268 .read_exact(&mut response_buf)
269 .await
270 .map_err(|e| Error::Io(Arc::new(e)))?;
271
272 let prelogin_response =
273 PreLogin::decode(&response_buf[..]).map_err(|e| Error::Protocol(e.to_string()))?;
274
275 let server_encryption = prelogin_response.encryption;
277 tracing::debug!(encryption = ?server_encryption, "server encryption level");
278
279 let negotiated_encryption = match (client_encryption, server_encryption) {
285 (EncryptionLevel::NotSupported, EncryptionLevel::NotSupported) => {
286 EncryptionLevel::NotSupported
287 }
288 (EncryptionLevel::Off, EncryptionLevel::Off) => EncryptionLevel::Off,
289 (EncryptionLevel::On, EncryptionLevel::Off)
290 | (EncryptionLevel::On, EncryptionLevel::NotSupported) => {
291 return Err(Error::Protocol(
292 "Server does not support requested encryption level".to_string(),
293 ));
294 }
295 _ => EncryptionLevel::On,
296 };
297
298 let use_tls = negotiated_encryption != EncryptionLevel::NotSupported;
301
302 if use_tls {
303 let tls_config =
306 TlsConfig::new().trust_server_certificate(config.trust_server_certificate);
307
308 let tls_connector =
309 TlsConnector::new(tls_config).map_err(|e| Error::Tls(e.to_string()))?;
310
311 let mut tls_stream = timeout(
313 config.timeouts.tls_timeout,
314 tls_connector.connect_with_prelogin(tcp_stream, &config.host),
315 )
316 .await
317 .map_err(|_| Error::TlsTimeout)?
318 .map_err(|e| Error::Tls(e.to_string()))?;
319
320 tracing::debug!("TLS handshake completed (PreLogin wrapped)");
321
322 let login_only_encryption = negotiated_encryption == EncryptionLevel::Off;
324
325 if login_only_encryption {
326 use tokio::io::AsyncWriteExt;
334
335 let login = Self::build_login7(config);
337 let login_payload = login.encode();
338
339 let max_packet = MAX_PACKET_SIZE;
341 let max_payload = max_packet - PACKET_HEADER_SIZE;
342 let chunks: Vec<_> = login_payload.chunks(max_payload).collect();
343 let total_chunks = chunks.len();
344
345 for (i, chunk) in chunks.into_iter().enumerate() {
346 let is_last = i == total_chunks - 1;
347 let status = if is_last {
348 PacketStatus::END_OF_MESSAGE
349 } else {
350 PacketStatus::NORMAL
351 };
352
353 let header = PacketHeader::new(
354 PacketType::Tds7Login,
355 status,
356 (PACKET_HEADER_SIZE + chunk.len()) as u16,
357 );
358
359 let mut packet_buf = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk.len());
360 header.encode(&mut packet_buf);
361 packet_buf.put_slice(chunk);
362
363 tls_stream
364 .write_all(&packet_buf)
365 .await
366 .map_err(|e| Error::Io(Arc::new(e)))?;
367 }
368
369 tls_stream
371 .flush()
372 .await
373 .map_err(|e| Error::Io(Arc::new(e)))?;
374
375 tracing::debug!("Login7 sent through TLS, switching to plaintext for response");
376
377 let (wrapper, _client_conn) = tls_stream.into_inner();
381 let tcp_stream = wrapper.into_inner();
382
383 let mut connection = Connection::new(tcp_stream);
385
386 let (server_version, current_database, routing) =
388 Self::process_login_response(&mut connection).await?;
389
390 if let Some((host, port)) = routing {
392 return Err(Error::Routing { host, port });
393 }
394
395 Ok(Client {
397 config: config.clone(),
398 _state: PhantomData,
399 connection: Some(ConnectionHandle::Plain(connection)),
400 server_version,
401 current_database: current_database.clone(),
402 statement_cache: StatementCache::with_default_size(),
403 transaction_descriptor: 0, #[cfg(feature = "otel")]
405 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
406 .with_database(current_database.unwrap_or_default()),
407 })
408 } else {
409 let mut connection = Connection::new(tls_stream);
412
413 let login = Self::build_login7(config);
415 Self::send_login7(&mut connection, &login).await?;
416
417 let (server_version, current_database, routing) =
419 Self::process_login_response(&mut connection).await?;
420
421 if let Some((host, port)) = routing {
423 return Err(Error::Routing { host, port });
424 }
425
426 Ok(Client {
427 config: config.clone(),
428 _state: PhantomData,
429 connection: Some(ConnectionHandle::TlsPrelogin(connection)),
430 server_version,
431 current_database: current_database.clone(),
432 statement_cache: StatementCache::with_default_size(),
433 transaction_descriptor: 0, #[cfg(feature = "otel")]
435 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
436 .with_database(current_database.unwrap_or_default()),
437 })
438 }
439 } else {
440 tracing::warn!(
442 "Connecting without TLS encryption. This is insecure and should only be \
443 used for development/testing on trusted networks."
444 );
445
446 let login = Self::build_login7(config);
448 let login_bytes = login.encode();
449 tracing::debug!("Login7 packet built: {} bytes", login_bytes.len(),);
450 tracing::debug!(
452 "Login7 fixed header (94 bytes): {:02X?}",
453 &login_bytes[..login_bytes.len().min(94)]
454 );
455 if login_bytes.len() > 94 {
457 tracing::debug!(
458 "Login7 variable data ({} bytes): {:02X?}",
459 login_bytes.len() - 94,
460 &login_bytes[94..]
461 );
462 }
463
464 let login_header = PacketHeader::new(
466 PacketType::Tds7Login,
467 PacketStatus::END_OF_MESSAGE,
468 (PACKET_HEADER_SIZE + login_bytes.len()) as u16,
469 )
470 .with_packet_id(1);
471 let mut login_packet_buf =
472 BytesMut::with_capacity(PACKET_HEADER_SIZE + login_bytes.len());
473 login_header.encode(&mut login_packet_buf);
474 login_packet_buf.put_slice(&login_bytes);
475
476 tracing::debug!(
477 "Sending Login7 packet: {} bytes total, header: {:02X?}",
478 login_packet_buf.len(),
479 &login_packet_buf[..PACKET_HEADER_SIZE]
480 );
481 tcp_stream
482 .write_all(&login_packet_buf)
483 .await
484 .map_err(|e| Error::Io(Arc::new(e)))?;
485 tcp_stream
486 .flush()
487 .await
488 .map_err(|e| Error::Io(Arc::new(e)))?;
489 tracing::debug!("Login7 sent and flushed over raw TCP");
490
491 let mut response_header_buf = [0u8; PACKET_HEADER_SIZE];
493 tcp_stream
494 .read_exact(&mut response_header_buf)
495 .await
496 .map_err(|e| Error::Io(Arc::new(e)))?;
497
498 let response_type = response_header_buf[0];
499 let response_length =
500 u16::from_be_bytes([response_header_buf[2], response_header_buf[3]]) as usize;
501 tracing::debug!(
502 "Response header: type={:#04X}, length={}",
503 response_type,
504 response_length
505 );
506
507 let payload_length = response_length.saturating_sub(PACKET_HEADER_SIZE);
509 let mut response_payload = vec![0u8; payload_length];
510 tcp_stream
511 .read_exact(&mut response_payload)
512 .await
513 .map_err(|e| Error::Io(Arc::new(e)))?;
514 tracing::debug!(
515 "Response payload: {} bytes, first 32: {:02X?}",
516 response_payload.len(),
517 &response_payload[..response_payload.len().min(32)]
518 );
519
520 let connection = Connection::new(tcp_stream);
522
523 let response_bytes = bytes::Bytes::from(response_payload);
525 let mut parser = TokenParser::new(response_bytes);
526 let mut server_version = None;
527 let mut current_database = None;
528 let routing = None;
529
530 while let Some(token) = parser
531 .next_token()
532 .map_err(|e| Error::Protocol(e.to_string()))?
533 {
534 match token {
535 Token::LoginAck(ack) => {
536 tracing::info!(
537 version = ack.tds_version,
538 interface = ack.interface,
539 prog_name = %ack.prog_name,
540 "login acknowledged"
541 );
542 server_version = Some(ack.tds_version);
543 }
544 Token::EnvChange(env) => {
545 Self::process_env_change(&env, &mut current_database, &mut None);
546 }
547 Token::Error(err) => {
548 return Err(Error::Server {
549 number: err.number,
550 state: err.state,
551 class: err.class,
552 message: err.message.clone(),
553 server: if err.server.is_empty() {
554 None
555 } else {
556 Some(err.server.clone())
557 },
558 procedure: if err.procedure.is_empty() {
559 None
560 } else {
561 Some(err.procedure.clone())
562 },
563 line: err.line as u32,
564 });
565 }
566 Token::Info(info) => {
567 tracing::info!(
568 number = info.number,
569 message = %info.message,
570 "server info message"
571 );
572 }
573 Token::Done(done) => {
574 if done.status.error {
575 return Err(Error::Protocol("login failed".to_string()));
576 }
577 break;
578 }
579 _ => {}
580 }
581 }
582
583 if let Some((host, port)) = routing {
585 return Err(Error::Routing { host, port });
586 }
587
588 Ok(Client {
589 config: config.clone(),
590 _state: PhantomData,
591 connection: Some(ConnectionHandle::Plain(connection)),
592 server_version,
593 current_database: current_database.clone(),
594 statement_cache: StatementCache::with_default_size(),
595 transaction_descriptor: 0, #[cfg(feature = "otel")]
597 instrumentation: InstrumentationContext::new(config.host.clone(), config.port)
598 .with_database(current_database.unwrap_or_default()),
599 })
600 }
601 }
602
603 fn build_prelogin(config: &Config, encryption: EncryptionLevel) -> PreLogin {
605 let mut prelogin = PreLogin::new().with_encryption(encryption);
606
607 if config.mars {
608 prelogin = prelogin.with_mars(true);
609 }
610
611 if let Some(ref instance) = config.instance {
612 prelogin = prelogin.with_instance(instance);
613 }
614
615 prelogin
616 }
617
618 fn build_login7(config: &Config) -> Login7 {
620 let mut login = Login7::new()
621 .with_packet_size(config.packet_size as u32)
622 .with_app_name(&config.application_name)
623 .with_server_name(&config.host)
624 .with_hostname(&config.host);
625
626 if let Some(ref database) = config.database {
627 login = login.with_database(database);
628 }
629
630 match &config.credentials {
632 mssql_auth::Credentials::SqlServer { username, password } => {
633 login = login.with_sql_auth(username.as_ref(), password.as_ref());
634 }
635 _ => {}
637 }
638
639 login
640 }
641
642 async fn send_prelogin<T>(connection: &mut Connection<T>, prelogin: &PreLogin) -> Result<()>
644 where
645 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
646 {
647 let payload = prelogin.encode();
648 let max_packet = MAX_PACKET_SIZE;
649
650 connection
651 .send_message(PacketType::PreLogin, payload, max_packet)
652 .await
653 .map_err(|e| Error::Protocol(e.to_string()))
654 }
655
656 async fn receive_prelogin<T>(connection: &mut Connection<T>) -> Result<PreLogin>
658 where
659 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
660 {
661 let message = connection
662 .read_message()
663 .await
664 .map_err(|e| Error::Protocol(e.to_string()))?
665 .ok_or(Error::ConnectionClosed)?;
666
667 PreLogin::decode(&message.payload[..]).map_err(|e| Error::Protocol(e.to_string()))
668 }
669
670 async fn send_login7<T>(connection: &mut Connection<T>, login: &Login7) -> Result<()>
672 where
673 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
674 {
675 let payload = login.encode();
676 let max_packet = MAX_PACKET_SIZE;
677
678 connection
679 .send_message(PacketType::Tds7Login, payload, max_packet)
680 .await
681 .map_err(|e| Error::Protocol(e.to_string()))
682 }
683
684 async fn process_login_response<T>(
688 connection: &mut Connection<T>,
689 ) -> Result<(Option<u32>, Option<String>, Option<(String, u16)>)>
690 where
691 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
692 {
693 let message = connection
694 .read_message()
695 .await
696 .map_err(|e| Error::Protocol(e.to_string()))?
697 .ok_or(Error::ConnectionClosed)?;
698
699 let response_bytes = message.payload;
700
701 let mut parser = TokenParser::new(response_bytes);
702 let mut server_version = None;
703 let mut database = None;
704 let mut routing = None;
705
706 while let Some(token) = parser
707 .next_token()
708 .map_err(|e| Error::Protocol(e.to_string()))?
709 {
710 match token {
711 Token::LoginAck(ack) => {
712 tracing::info!(
713 version = ack.tds_version,
714 interface = ack.interface,
715 prog_name = %ack.prog_name,
716 "login acknowledged"
717 );
718 server_version = Some(ack.tds_version);
719 }
720 Token::EnvChange(env) => {
721 Self::process_env_change(&env, &mut database, &mut routing);
722 }
723 Token::Error(err) => {
724 return Err(Error::Server {
725 number: err.number,
726 state: err.state,
727 class: err.class,
728 message: err.message.clone(),
729 server: if err.server.is_empty() {
730 None
731 } else {
732 Some(err.server.clone())
733 },
734 procedure: if err.procedure.is_empty() {
735 None
736 } else {
737 Some(err.procedure.clone())
738 },
739 line: err.line as u32,
740 });
741 }
742 Token::Info(info) => {
743 tracing::info!(
744 number = info.number,
745 message = %info.message,
746 "server info message"
747 );
748 }
749 Token::Done(done) => {
750 if done.status.error {
751 return Err(Error::Protocol("login failed".to_string()));
752 }
753 break;
754 }
755 _ => {}
756 }
757 }
758
759 Ok((server_version, database, routing))
760 }
761
762 fn process_env_change(
764 env: &EnvChange,
765 database: &mut Option<String>,
766 routing: &mut Option<(String, u16)>,
767 ) {
768 use tds_protocol::token::EnvChangeValue;
769
770 match env.env_type {
771 EnvChangeType::Database => {
772 if let EnvChangeValue::String(ref new_value) = env.new_value {
773 tracing::debug!(database = %new_value, "database changed");
774 *database = Some(new_value.clone());
775 }
776 }
777 EnvChangeType::Routing => {
778 if let EnvChangeValue::Routing { ref host, port } = env.new_value {
779 tracing::info!(host = %host, port = port, "routing redirect received");
780 *routing = Some((host.clone(), port));
781 }
782 }
783 _ => {
784 if let EnvChangeValue::String(ref new_value) = env.new_value {
785 tracing::debug!(
786 env_type = ?env.env_type,
787 new_value = %new_value,
788 "environment change"
789 );
790 }
791 }
792 }
793 }
794}
795
796impl<S: ConnectionState> Client<S> {
798 fn process_transaction_env_change(env: &EnvChange, transaction_descriptor: &mut u64) {
806 use tds_protocol::token::EnvChangeValue;
807
808 match env.env_type {
809 EnvChangeType::BeginTransaction => {
810 if let EnvChangeValue::Binary(ref data) = env.new_value {
811 if data.len() >= 8 {
812 let descriptor = u64::from_le_bytes([
813 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
814 ]);
815 tracing::debug!(descriptor = descriptor, "transaction started via raw SQL");
816 *transaction_descriptor = descriptor;
817 }
818 }
819 }
820 EnvChangeType::CommitTransaction | EnvChangeType::RollbackTransaction => {
821 tracing::debug!(
822 env_type = ?env.env_type,
823 "transaction ended via raw SQL"
824 );
825 *transaction_descriptor = 0;
826 }
827 _ => {}
828 }
829 }
830
831 async fn send_sql_batch(&mut self, sql: &str) -> Result<()> {
837 let payload =
838 tds_protocol::encode_sql_batch_with_transaction(sql, self.transaction_descriptor);
839 let max_packet = self.config.packet_size as usize;
840
841 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
842
843 match connection {
844 ConnectionHandle::Tls(conn) => {
845 conn.send_message(PacketType::SqlBatch, payload, max_packet)
846 .await
847 .map_err(|e| Error::Protocol(e.to_string()))?;
848 }
849 ConnectionHandle::TlsPrelogin(conn) => {
850 conn.send_message(PacketType::SqlBatch, payload, max_packet)
851 .await
852 .map_err(|e| Error::Protocol(e.to_string()))?;
853 }
854 ConnectionHandle::Plain(conn) => {
855 conn.send_message(PacketType::SqlBatch, payload, max_packet)
856 .await
857 .map_err(|e| Error::Protocol(e.to_string()))?;
858 }
859 }
860
861 Ok(())
862 }
863
864 async fn send_rpc(&mut self, rpc: &RpcRequest) -> Result<()> {
868 let payload = rpc.encode_with_transaction(self.transaction_descriptor);
869 let max_packet = self.config.packet_size as usize;
870
871 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
872
873 match connection {
874 ConnectionHandle::Tls(conn) => {
875 conn.send_message(PacketType::Rpc, payload, max_packet)
876 .await
877 .map_err(|e| Error::Protocol(e.to_string()))?;
878 }
879 ConnectionHandle::TlsPrelogin(conn) => {
880 conn.send_message(PacketType::Rpc, payload, max_packet)
881 .await
882 .map_err(|e| Error::Protocol(e.to_string()))?;
883 }
884 ConnectionHandle::Plain(conn) => {
885 conn.send_message(PacketType::Rpc, payload, max_packet)
886 .await
887 .map_err(|e| Error::Protocol(e.to_string()))?;
888 }
889 }
890
891 Ok(())
892 }
893
894 fn convert_params(params: &[&(dyn crate::ToSql + Sync)]) -> Result<Vec<RpcParam>> {
896 use bytes::{BufMut, BytesMut};
897 use mssql_types::SqlValue;
898
899 params
900 .iter()
901 .enumerate()
902 .map(|(i, p)| {
903 let sql_value = p.to_sql()?;
904 let name = format!("@p{}", i + 1);
905
906 Ok(match sql_value {
907 SqlValue::Null => RpcParam::null(&name, RpcTypeInfo::nvarchar(1)),
908 SqlValue::Bool(v) => {
909 let mut buf = BytesMut::with_capacity(1);
910 buf.put_u8(if v { 1 } else { 0 });
911 RpcParam::new(&name, RpcTypeInfo::bit(), buf.freeze())
912 }
913 SqlValue::TinyInt(v) => {
914 let mut buf = BytesMut::with_capacity(1);
915 buf.put_u8(v);
916 RpcParam::new(&name, RpcTypeInfo::tinyint(), buf.freeze())
917 }
918 SqlValue::SmallInt(v) => {
919 let mut buf = BytesMut::with_capacity(2);
920 buf.put_i16_le(v);
921 RpcParam::new(&name, RpcTypeInfo::smallint(), buf.freeze())
922 }
923 SqlValue::Int(v) => RpcParam::int(&name, v),
924 SqlValue::BigInt(v) => RpcParam::bigint(&name, v),
925 SqlValue::Float(v) => {
926 let mut buf = BytesMut::with_capacity(4);
927 buf.put_f32_le(v);
928 RpcParam::new(&name, RpcTypeInfo::real(), buf.freeze())
929 }
930 SqlValue::Double(v) => {
931 let mut buf = BytesMut::with_capacity(8);
932 buf.put_f64_le(v);
933 RpcParam::new(&name, RpcTypeInfo::float(), buf.freeze())
934 }
935 SqlValue::String(ref s) => RpcParam::nvarchar(&name, s),
936 SqlValue::Binary(ref b) => {
937 RpcParam::new(&name, RpcTypeInfo::varbinary(b.len() as u16), b.clone())
938 }
939 SqlValue::Xml(ref s) => RpcParam::nvarchar(&name, s),
940 #[cfg(feature = "uuid")]
941 SqlValue::Uuid(u) => {
942 let bytes = u.as_bytes();
944 let mut buf = BytesMut::with_capacity(16);
945 buf.put_u32_le(u32::from_be_bytes([
947 bytes[0], bytes[1], bytes[2], bytes[3],
948 ]));
949 buf.put_u16_le(u16::from_be_bytes([bytes[4], bytes[5]]));
950 buf.put_u16_le(u16::from_be_bytes([bytes[6], bytes[7]]));
951 buf.put_slice(&bytes[8..16]);
952 RpcParam::new(&name, RpcTypeInfo::uniqueidentifier(), buf.freeze())
953 }
954 #[cfg(feature = "decimal")]
955 SqlValue::Decimal(d) => {
956 RpcParam::nvarchar(&name, &d.to_string())
958 }
959 #[cfg(feature = "chrono")]
960 SqlValue::Date(_)
961 | SqlValue::Time(_)
962 | SqlValue::DateTime(_)
963 | SqlValue::DateTimeOffset(_) => {
964 let s = match &sql_value {
967 SqlValue::Date(d) => d.to_string(),
968 SqlValue::Time(t) => t.to_string(),
969 SqlValue::DateTime(dt) => dt.to_string(),
970 SqlValue::DateTimeOffset(dto) => dto.to_rfc3339(),
971 _ => unreachable!(),
972 };
973 RpcParam::nvarchar(&name, &s)
974 }
975 #[cfg(feature = "json")]
976 SqlValue::Json(ref j) => RpcParam::nvarchar(&name, &j.to_string()),
977 SqlValue::Tvp(ref tvp_data) => {
978 Self::encode_tvp_param(&name, tvp_data)?
980 }
981 _ => {
983 return Err(Error::Type(mssql_types::TypeError::UnsupportedConversion {
984 from: sql_value.type_name().to_string(),
985 to: "RPC parameter",
986 }));
987 }
988 })
989 })
990 .collect()
991 }
992
993 fn encode_tvp_param(name: &str, tvp_data: &mssql_types::TvpData) -> Result<RpcParam> {
998 let wire_columns: Vec<TvpWireColumnDef> = tvp_data
1000 .columns
1001 .iter()
1002 .map(|col| {
1003 let wire_type = Self::convert_tvp_column_type(&col.column_type);
1004 TvpWireColumnDef {
1005 wire_type,
1006 flags: TvpColumnFlags {
1007 nullable: col.nullable,
1008 },
1009 }
1010 })
1011 .collect();
1012
1013 let encoder = TvpEncoder::new(&tvp_data.schema, &tvp_data.type_name, &wire_columns);
1015
1016 let mut buf = BytesMut::with_capacity(256);
1018
1019 encoder.encode_metadata(&mut buf);
1021
1022 for row in &tvp_data.rows {
1024 encoder.encode_row(&mut buf, |row_buf| {
1025 for (col_idx, value) in row.iter().enumerate() {
1026 let wire_type = &wire_columns[col_idx].wire_type;
1027 Self::encode_tvp_value(value, wire_type, row_buf);
1028 }
1029 });
1030 }
1031
1032 encoder.encode_end(&mut buf);
1034
1035 let type_info = RpcTypeInfo {
1039 type_id: 0xF3, max_length: None,
1041 precision: None,
1042 scale: None,
1043 collation: None,
1044 };
1045
1046 Ok(RpcParam {
1047 name: name.to_string(),
1048 flags: tds_protocol::rpc::ParamFlags::default(),
1049 type_info,
1050 value: Some(buf.freeze()),
1051 })
1052 }
1053
1054 fn convert_tvp_column_type(col_type: &mssql_types::TvpColumnType) -> TvpWireType {
1056 match col_type {
1057 mssql_types::TvpColumnType::Bit => TvpWireType::Bit,
1058 mssql_types::TvpColumnType::TinyInt => TvpWireType::Int { size: 1 },
1059 mssql_types::TvpColumnType::SmallInt => TvpWireType::Int { size: 2 },
1060 mssql_types::TvpColumnType::Int => TvpWireType::Int { size: 4 },
1061 mssql_types::TvpColumnType::BigInt => TvpWireType::Int { size: 8 },
1062 mssql_types::TvpColumnType::Real => TvpWireType::Float { size: 4 },
1063 mssql_types::TvpColumnType::Float => TvpWireType::Float { size: 8 },
1064 mssql_types::TvpColumnType::Decimal { precision, scale } => TvpWireType::Decimal {
1065 precision: *precision,
1066 scale: *scale,
1067 },
1068 mssql_types::TvpColumnType::NVarChar { max_length } => TvpWireType::NVarChar {
1069 max_length: *max_length,
1070 },
1071 mssql_types::TvpColumnType::VarChar { max_length } => TvpWireType::VarChar {
1072 max_length: *max_length,
1073 },
1074 mssql_types::TvpColumnType::VarBinary { max_length } => TvpWireType::VarBinary {
1075 max_length: *max_length,
1076 },
1077 mssql_types::TvpColumnType::UniqueIdentifier => TvpWireType::Guid,
1078 mssql_types::TvpColumnType::Date => TvpWireType::Date,
1079 mssql_types::TvpColumnType::Time { scale } => TvpWireType::Time { scale: *scale },
1080 mssql_types::TvpColumnType::DateTime2 { scale } => {
1081 TvpWireType::DateTime2 { scale: *scale }
1082 }
1083 mssql_types::TvpColumnType::DateTimeOffset { scale } => {
1084 TvpWireType::DateTimeOffset { scale: *scale }
1085 }
1086 mssql_types::TvpColumnType::Xml => TvpWireType::Xml,
1087 }
1088 }
1089
1090 fn encode_tvp_value(
1092 value: &mssql_types::SqlValue,
1093 wire_type: &TvpWireType,
1094 buf: &mut BytesMut,
1095 ) {
1096 use mssql_types::SqlValue;
1097
1098 match value {
1099 SqlValue::Null => {
1100 encode_tvp_null(wire_type, buf);
1101 }
1102 SqlValue::Bool(v) => {
1103 encode_tvp_bit(*v, buf);
1104 }
1105 SqlValue::TinyInt(v) => {
1106 encode_tvp_int(*v as i64, 1, buf);
1107 }
1108 SqlValue::SmallInt(v) => {
1109 encode_tvp_int(*v as i64, 2, buf);
1110 }
1111 SqlValue::Int(v) => {
1112 encode_tvp_int(*v as i64, 4, buf);
1113 }
1114 SqlValue::BigInt(v) => {
1115 encode_tvp_int(*v, 8, buf);
1116 }
1117 SqlValue::Float(v) => {
1118 encode_tvp_float(*v as f64, 4, buf);
1119 }
1120 SqlValue::Double(v) => {
1121 encode_tvp_float(*v, 8, buf);
1122 }
1123 SqlValue::String(s) => {
1124 let max_len = match wire_type {
1125 TvpWireType::NVarChar { max_length } => *max_length,
1126 _ => 4000,
1127 };
1128 encode_tvp_nvarchar(s, max_len, buf);
1129 }
1130 SqlValue::Binary(b) => {
1131 let max_len = match wire_type {
1132 TvpWireType::VarBinary { max_length } => *max_length,
1133 _ => 8000,
1134 };
1135 encode_tvp_varbinary(b, max_len, buf);
1136 }
1137 #[cfg(feature = "decimal")]
1138 SqlValue::Decimal(d) => {
1139 let sign = if d.is_sign_negative() { 0u8 } else { 1u8 };
1140 let mantissa = d.mantissa().unsigned_abs();
1141 encode_tvp_decimal(sign, mantissa, buf);
1142 }
1143 #[cfg(feature = "uuid")]
1144 SqlValue::Uuid(u) => {
1145 let bytes = u.as_bytes();
1146 tds_protocol::tvp::encode_tvp_guid(bytes, buf);
1147 }
1148 #[cfg(feature = "chrono")]
1149 SqlValue::Date(d) => {
1150 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1152 let days = d.signed_duration_since(base).num_days() as u32;
1153 tds_protocol::tvp::encode_tvp_date(days, buf);
1154 }
1155 #[cfg(feature = "chrono")]
1156 SqlValue::Time(t) => {
1157 use chrono::Timelike;
1158 let nanos =
1159 t.num_seconds_from_midnight() as u64 * 1_000_000_000 + t.nanosecond() as u64;
1160 let intervals = nanos / 100;
1161 let scale = match wire_type {
1162 TvpWireType::Time { scale } => *scale,
1163 _ => 7,
1164 };
1165 tds_protocol::tvp::encode_tvp_time(intervals, scale, buf);
1166 }
1167 #[cfg(feature = "chrono")]
1168 SqlValue::DateTime(dt) => {
1169 use chrono::Timelike;
1170 let nanos = dt.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1172 + dt.time().nanosecond() as u64;
1173 let intervals = nanos / 100;
1174 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1176 let days = dt.date().signed_duration_since(base).num_days() as u32;
1177 let scale = match wire_type {
1178 TvpWireType::DateTime2 { scale } => *scale,
1179 _ => 7,
1180 };
1181 tds_protocol::tvp::encode_tvp_datetime2(intervals, days, scale, buf);
1182 }
1183 #[cfg(feature = "chrono")]
1184 SqlValue::DateTimeOffset(dto) => {
1185 use chrono::{Offset, Timelike};
1186 let nanos = dto.time().num_seconds_from_midnight() as u64 * 1_000_000_000
1188 + dto.time().nanosecond() as u64;
1189 let intervals = nanos / 100;
1190 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1192 let days = dto.date_naive().signed_duration_since(base).num_days() as u32;
1193 let offset_minutes = (dto.offset().fix().local_minus_utc() / 60) as i16;
1195 let scale = match wire_type {
1196 TvpWireType::DateTimeOffset { scale } => *scale,
1197 _ => 7,
1198 };
1199 tds_protocol::tvp::encode_tvp_datetimeoffset(
1200 intervals,
1201 days,
1202 offset_minutes,
1203 scale,
1204 buf,
1205 );
1206 }
1207 #[cfg(feature = "json")]
1208 SqlValue::Json(j) => {
1209 encode_tvp_nvarchar(&j.to_string(), 0xFFFF, buf);
1211 }
1212 SqlValue::Xml(s) => {
1213 encode_tvp_nvarchar(s, 0xFFFF, buf);
1215 }
1216 SqlValue::Tvp(_) => {
1217 encode_tvp_null(wire_type, buf);
1219 }
1220 _ => {
1222 encode_tvp_null(wire_type, buf);
1223 }
1224 }
1225 }
1226
1227 async fn read_query_response(
1229 &mut self,
1230 ) -> Result<(Vec<crate::row::Column>, Vec<crate::row::Row>)> {
1231 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
1232
1233 let message = match connection {
1234 ConnectionHandle::Tls(conn) => conn
1235 .read_message()
1236 .await
1237 .map_err(|e| Error::Protocol(e.to_string()))?,
1238 ConnectionHandle::TlsPrelogin(conn) => conn
1239 .read_message()
1240 .await
1241 .map_err(|e| Error::Protocol(e.to_string()))?,
1242 ConnectionHandle::Plain(conn) => conn
1243 .read_message()
1244 .await
1245 .map_err(|e| Error::Protocol(e.to_string()))?,
1246 }
1247 .ok_or(Error::ConnectionClosed)?;
1248
1249 let mut parser = TokenParser::new(message.payload);
1250 let mut columns: Vec<crate::row::Column> = Vec::new();
1251 let mut rows: Vec<crate::row::Row> = Vec::new();
1252 let mut protocol_metadata: Option<ColMetaData> = None;
1253
1254 loop {
1255 let token = parser
1257 .next_token_with_metadata(protocol_metadata.as_ref())
1258 .map_err(|e| Error::Protocol(e.to_string()))?;
1259
1260 let Some(token) = token else {
1261 break;
1262 };
1263
1264 match token {
1265 Token::ColMetaData(meta) => {
1266 rows.clear();
1269
1270 columns = meta
1271 .columns
1272 .iter()
1273 .enumerate()
1274 .map(|(i, col)| {
1275 let type_name = format!("{:?}", col.type_id);
1276 let mut column = crate::row::Column::new(&col.name, i, type_name)
1277 .with_nullable(col.flags & 0x01 != 0);
1278
1279 if let Some(max_len) = col.type_info.max_length {
1280 column = column.with_max_length(max_len);
1281 }
1282 if let (Some(prec), Some(scale)) =
1283 (col.type_info.precision, col.type_info.scale)
1284 {
1285 column = column.with_precision_scale(prec, scale);
1286 }
1287 column
1288 })
1289 .collect();
1290
1291 tracing::debug!(columns = columns.len(), "received column metadata");
1292 protocol_metadata = Some(meta);
1293 }
1294 Token::Row(raw_row) => {
1295 if let Some(ref meta) = protocol_metadata {
1296 let row = Self::convert_raw_row(&raw_row, meta, &columns)?;
1297 rows.push(row);
1298 }
1299 }
1300 Token::NbcRow(nbc_row) => {
1301 if let Some(ref meta) = protocol_metadata {
1302 let row = Self::convert_nbc_row(&nbc_row, meta, &columns)?;
1303 rows.push(row);
1304 }
1305 }
1306 Token::Error(err) => {
1307 return Err(Error::Server {
1308 number: err.number,
1309 state: err.state,
1310 class: err.class,
1311 message: err.message.clone(),
1312 server: if err.server.is_empty() {
1313 None
1314 } else {
1315 Some(err.server.clone())
1316 },
1317 procedure: if err.procedure.is_empty() {
1318 None
1319 } else {
1320 Some(err.procedure.clone())
1321 },
1322 line: err.line as u32,
1323 });
1324 }
1325 Token::Done(done) => {
1326 if done.status.error {
1327 return Err(Error::Query("query failed".to_string()));
1328 }
1329 tracing::debug!(
1330 row_count = done.row_count,
1331 has_more = done.status.more,
1332 "query complete"
1333 );
1334 if !done.status.more {
1337 break;
1338 }
1339 }
1340 Token::DoneProc(done) => {
1341 if done.status.error {
1342 return Err(Error::Query("query failed".to_string()));
1343 }
1344 }
1345 Token::DoneInProc(done) => {
1346 if done.status.error {
1347 return Err(Error::Query("query failed".to_string()));
1348 }
1349 }
1350 Token::Info(info) => {
1351 tracing::debug!(
1352 number = info.number,
1353 message = %info.message,
1354 "server info message"
1355 );
1356 }
1357 Token::EnvChange(env) => {
1358 Self::process_transaction_env_change(&env, &mut self.transaction_descriptor);
1362 }
1363 _ => {}
1364 }
1365 }
1366
1367 tracing::debug!(
1368 columns = columns.len(),
1369 rows = rows.len(),
1370 "query response parsed"
1371 );
1372 Ok((columns, rows))
1373 }
1374
1375 fn convert_raw_row(
1379 raw: &RawRow,
1380 meta: &ColMetaData,
1381 columns: &[crate::row::Column],
1382 ) -> Result<crate::row::Row> {
1383 let mut values = Vec::with_capacity(meta.columns.len());
1384 let mut buf = raw.data.as_ref();
1385
1386 for col in &meta.columns {
1387 let value = Self::parse_column_value(&mut buf, col)?;
1388 values.push(value);
1389 }
1390
1391 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1392 }
1393
1394 fn convert_nbc_row(
1398 nbc: &NbcRow,
1399 meta: &ColMetaData,
1400 columns: &[crate::row::Column],
1401 ) -> Result<crate::row::Row> {
1402 let mut values = Vec::with_capacity(meta.columns.len());
1403 let mut buf = nbc.data.as_ref();
1404
1405 for (i, col) in meta.columns.iter().enumerate() {
1406 if nbc.is_null(i) {
1407 values.push(mssql_types::SqlValue::Null);
1408 } else {
1409 let value = Self::parse_column_value(&mut buf, col)?;
1410 values.push(value);
1411 }
1412 }
1413
1414 Ok(crate::row::Row::from_values(columns.to_vec(), values))
1415 }
1416
1417 fn parse_column_value(buf: &mut &[u8], col: &ColumnData) -> Result<mssql_types::SqlValue> {
1419 use bytes::Buf;
1420 use mssql_types::SqlValue;
1421 use tds_protocol::types::TypeId;
1422
1423 let value = match col.type_id {
1424 TypeId::Null => SqlValue::Null,
1426
1427 TypeId::Int1 => {
1429 if buf.remaining() < 1 {
1430 return Err(Error::Protocol("unexpected EOF reading TINYINT".into()));
1431 }
1432 SqlValue::TinyInt(buf.get_u8())
1433 }
1434 TypeId::Bit => {
1435 if buf.remaining() < 1 {
1436 return Err(Error::Protocol("unexpected EOF reading BIT".into()));
1437 }
1438 SqlValue::Bool(buf.get_u8() != 0)
1439 }
1440
1441 TypeId::Int2 => {
1443 if buf.remaining() < 2 {
1444 return Err(Error::Protocol("unexpected EOF reading SMALLINT".into()));
1445 }
1446 SqlValue::SmallInt(buf.get_i16_le())
1447 }
1448
1449 TypeId::Int4 => {
1451 if buf.remaining() < 4 {
1452 return Err(Error::Protocol("unexpected EOF reading INT".into()));
1453 }
1454 SqlValue::Int(buf.get_i32_le())
1455 }
1456 TypeId::Float4 => {
1457 if buf.remaining() < 4 {
1458 return Err(Error::Protocol("unexpected EOF reading REAL".into()));
1459 }
1460 SqlValue::Float(buf.get_f32_le())
1461 }
1462
1463 TypeId::Int8 => {
1465 if buf.remaining() < 8 {
1466 return Err(Error::Protocol("unexpected EOF reading BIGINT".into()));
1467 }
1468 SqlValue::BigInt(buf.get_i64_le())
1469 }
1470 TypeId::Float8 => {
1471 if buf.remaining() < 8 {
1472 return Err(Error::Protocol("unexpected EOF reading FLOAT".into()));
1473 }
1474 SqlValue::Double(buf.get_f64_le())
1475 }
1476 TypeId::Money => {
1477 if buf.remaining() < 8 {
1478 return Err(Error::Protocol("unexpected EOF reading MONEY".into()));
1479 }
1480 let high = buf.get_i32_le();
1482 let low = buf.get_u32_le();
1483 let cents = ((high as i64) << 32) | (low as i64);
1484 let value = (cents as f64) / 10000.0;
1485 SqlValue::Double(value)
1486 }
1487 TypeId::Money4 => {
1488 if buf.remaining() < 4 {
1489 return Err(Error::Protocol("unexpected EOF reading SMALLMONEY".into()));
1490 }
1491 let cents = buf.get_i32_le();
1492 let value = (cents as f64) / 10000.0;
1493 SqlValue::Double(value)
1494 }
1495
1496 TypeId::IntN => {
1498 if buf.remaining() < 1 {
1499 return Err(Error::Protocol("unexpected EOF reading IntN length".into()));
1500 }
1501 let len = buf.get_u8();
1502 match len {
1503 0 => SqlValue::Null,
1504 1 => SqlValue::TinyInt(buf.get_u8()),
1505 2 => SqlValue::SmallInt(buf.get_i16_le()),
1506 4 => SqlValue::Int(buf.get_i32_le()),
1507 8 => SqlValue::BigInt(buf.get_i64_le()),
1508 _ => {
1509 return Err(Error::Protocol(format!("invalid IntN length: {len}")));
1510 }
1511 }
1512 }
1513 TypeId::FloatN => {
1514 if buf.remaining() < 1 {
1515 return Err(Error::Protocol(
1516 "unexpected EOF reading FloatN length".into(),
1517 ));
1518 }
1519 let len = buf.get_u8();
1520 match len {
1521 0 => SqlValue::Null,
1522 4 => SqlValue::Float(buf.get_f32_le()),
1523 8 => SqlValue::Double(buf.get_f64_le()),
1524 _ => {
1525 return Err(Error::Protocol(format!("invalid FloatN length: {len}")));
1526 }
1527 }
1528 }
1529 TypeId::BitN => {
1530 if buf.remaining() < 1 {
1531 return Err(Error::Protocol("unexpected EOF reading BitN length".into()));
1532 }
1533 let len = buf.get_u8();
1534 match len {
1535 0 => SqlValue::Null,
1536 1 => SqlValue::Bool(buf.get_u8() != 0),
1537 _ => {
1538 return Err(Error::Protocol(format!("invalid BitN length: {len}")));
1539 }
1540 }
1541 }
1542 TypeId::MoneyN => {
1543 if buf.remaining() < 1 {
1544 return Err(Error::Protocol(
1545 "unexpected EOF reading MoneyN length".into(),
1546 ));
1547 }
1548 let len = buf.get_u8();
1549 match len {
1550 0 => SqlValue::Null,
1551 4 => {
1552 let cents = buf.get_i32_le();
1553 SqlValue::Double((cents as f64) / 10000.0)
1554 }
1555 8 => {
1556 let high = buf.get_i32_le();
1557 let low = buf.get_u32_le();
1558 let cents = ((high as i64) << 32) | (low as i64);
1559 SqlValue::Double((cents as f64) / 10000.0)
1560 }
1561 _ => {
1562 return Err(Error::Protocol(format!("invalid MoneyN length: {len}")));
1563 }
1564 }
1565 }
1566 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1568 if buf.remaining() < 1 {
1569 return Err(Error::Protocol(
1570 "unexpected EOF reading DECIMAL/NUMERIC length".into(),
1571 ));
1572 }
1573 let len = buf.get_u8() as usize;
1574 if len == 0 {
1575 SqlValue::Null
1576 } else {
1577 if buf.remaining() < len {
1578 return Err(Error::Protocol(
1579 "unexpected EOF reading DECIMAL/NUMERIC data".into(),
1580 ));
1581 }
1582
1583 let sign = buf.get_u8();
1585 let mantissa_len = len - 1;
1586
1587 let mut mantissa_bytes = [0u8; 16];
1589 for i in 0..mantissa_len.min(16) {
1590 mantissa_bytes[i] = buf.get_u8();
1591 }
1592 for _ in 16..mantissa_len {
1594 buf.get_u8();
1595 }
1596
1597 let mantissa = u128::from_le_bytes(mantissa_bytes);
1598 let scale = col.type_info.scale.unwrap_or(0) as u32;
1599
1600 #[cfg(feature = "decimal")]
1601 {
1602 use rust_decimal::Decimal;
1603 let mut decimal = Decimal::from_i128_with_scale(mantissa as i128, scale);
1604 if sign == 0 {
1605 decimal.set_sign_negative(true);
1606 }
1607 SqlValue::Decimal(decimal)
1608 }
1609
1610 #[cfg(not(feature = "decimal"))]
1611 {
1612 let divisor = 10f64.powi(scale as i32);
1614 let value = (mantissa as f64) / divisor;
1615 let value = if sign == 0 { -value } else { value };
1616 SqlValue::Double(value)
1617 }
1618 }
1619 }
1620
1621 TypeId::DateTimeN => {
1623 if buf.remaining() < 1 {
1624 return Err(Error::Protocol(
1625 "unexpected EOF reading DateTimeN length".into(),
1626 ));
1627 }
1628 let len = buf.get_u8() as usize;
1629 if len == 0 {
1630 SqlValue::Null
1631 } else if buf.remaining() < len {
1632 return Err(Error::Protocol("unexpected EOF reading DateTimeN".into()));
1633 } else {
1634 match len {
1635 4 => {
1636 let days = buf.get_u16_le() as i64;
1638 let minutes = buf.get_u16_le() as u32;
1639 #[cfg(feature = "chrono")]
1640 {
1641 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1642 let date = base + chrono::Duration::days(days);
1643 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1644 minutes * 60,
1645 0,
1646 )
1647 .unwrap();
1648 SqlValue::DateTime(date.and_time(time))
1649 }
1650 #[cfg(not(feature = "chrono"))]
1651 {
1652 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1653 }
1654 }
1655 8 => {
1656 let days = buf.get_i32_le() as i64;
1658 let time_300ths = buf.get_u32_le() as u64;
1659 #[cfg(feature = "chrono")]
1660 {
1661 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1662 let date = base + chrono::Duration::days(days);
1663 let total_ms = (time_300ths * 1000) / 300;
1665 let secs = (total_ms / 1000) as u32;
1666 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1667 let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(
1668 secs, nanos,
1669 )
1670 .unwrap();
1671 SqlValue::DateTime(date.and_time(time))
1672 }
1673 #[cfg(not(feature = "chrono"))]
1674 {
1675 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1676 }
1677 }
1678 _ => {
1679 return Err(Error::Protocol(format!(
1680 "invalid DateTimeN length: {len}"
1681 )));
1682 }
1683 }
1684 }
1685 }
1686
1687 TypeId::DateTime => {
1689 if buf.remaining() < 8 {
1690 return Err(Error::Protocol("unexpected EOF reading DATETIME".into()));
1691 }
1692 let days = buf.get_i32_le() as i64;
1693 let time_300ths = buf.get_u32_le() as u64;
1694 #[cfg(feature = "chrono")]
1695 {
1696 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1697 let date = base + chrono::Duration::days(days);
1698 let total_ms = (time_300ths * 1000) / 300;
1699 let secs = (total_ms / 1000) as u32;
1700 let nanos = ((total_ms % 1000) * 1_000_000) as u32;
1701 let time =
1702 chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos).unwrap();
1703 SqlValue::DateTime(date.and_time(time))
1704 }
1705 #[cfg(not(feature = "chrono"))]
1706 {
1707 SqlValue::String(format!("DATETIME({days},{time_300ths})"))
1708 }
1709 }
1710
1711 TypeId::DateTime4 => {
1713 if buf.remaining() < 4 {
1714 return Err(Error::Protocol(
1715 "unexpected EOF reading SMALLDATETIME".into(),
1716 ));
1717 }
1718 let days = buf.get_u16_le() as i64;
1719 let minutes = buf.get_u16_le() as u32;
1720 #[cfg(feature = "chrono")]
1721 {
1722 let base = chrono::NaiveDate::from_ymd_opt(1900, 1, 1).unwrap();
1723 let date = base + chrono::Duration::days(days);
1724 let time =
1725 chrono::NaiveTime::from_num_seconds_from_midnight_opt(minutes * 60, 0)
1726 .unwrap();
1727 SqlValue::DateTime(date.and_time(time))
1728 }
1729 #[cfg(not(feature = "chrono"))]
1730 {
1731 SqlValue::String(format!("SMALLDATETIME({days},{minutes})"))
1732 }
1733 }
1734
1735 TypeId::Date => {
1737 if buf.remaining() < 1 {
1738 return Err(Error::Protocol("unexpected EOF reading DATE length".into()));
1739 }
1740 let len = buf.get_u8() as usize;
1741 if len == 0 {
1742 SqlValue::Null
1743 } else if len != 3 {
1744 return Err(Error::Protocol(format!("invalid DATE length: {len}")));
1745 } else if buf.remaining() < 3 {
1746 return Err(Error::Protocol("unexpected EOF reading DATE".into()));
1747 } else {
1748 let days = buf.get_u8() as u32
1750 | ((buf.get_u8() as u32) << 8)
1751 | ((buf.get_u8() as u32) << 16);
1752 #[cfg(feature = "chrono")]
1753 {
1754 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1755 let date = base + chrono::Duration::days(days as i64);
1756 SqlValue::Date(date)
1757 }
1758 #[cfg(not(feature = "chrono"))]
1759 {
1760 SqlValue::String(format!("DATE({days})"))
1761 }
1762 }
1763 }
1764
1765 TypeId::Time => {
1767 if buf.remaining() < 1 {
1768 return Err(Error::Protocol("unexpected EOF reading TIME length".into()));
1769 }
1770 let len = buf.get_u8() as usize;
1771 if len == 0 {
1772 SqlValue::Null
1773 } else if buf.remaining() < len {
1774 return Err(Error::Protocol("unexpected EOF reading TIME".into()));
1775 } else {
1776 let scale = col.type_info.scale.unwrap_or(7);
1777 let mut time_bytes = [0u8; 8];
1778 for byte in time_bytes.iter_mut().take(len) {
1779 *byte = buf.get_u8();
1780 }
1781 let intervals = u64::from_le_bytes(time_bytes);
1782 #[cfg(feature = "chrono")]
1783 {
1784 let time = Self::intervals_to_time(intervals, scale);
1785 SqlValue::Time(time)
1786 }
1787 #[cfg(not(feature = "chrono"))]
1788 {
1789 SqlValue::String(format!("TIME({intervals})"))
1790 }
1791 }
1792 }
1793
1794 TypeId::DateTime2 => {
1796 if buf.remaining() < 1 {
1797 return Err(Error::Protocol(
1798 "unexpected EOF reading DATETIME2 length".into(),
1799 ));
1800 }
1801 let len = buf.get_u8() as usize;
1802 if len == 0 {
1803 SqlValue::Null
1804 } else if buf.remaining() < len {
1805 return Err(Error::Protocol("unexpected EOF reading DATETIME2".into()));
1806 } else {
1807 let scale = col.type_info.scale.unwrap_or(7);
1808 let time_len = Self::time_bytes_for_scale(scale);
1809
1810 let mut time_bytes = [0u8; 8];
1812 for byte in time_bytes.iter_mut().take(time_len) {
1813 *byte = buf.get_u8();
1814 }
1815 let intervals = u64::from_le_bytes(time_bytes);
1816
1817 let days = buf.get_u8() as u32
1819 | ((buf.get_u8() as u32) << 8)
1820 | ((buf.get_u8() as u32) << 16);
1821
1822 #[cfg(feature = "chrono")]
1823 {
1824 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1825 let date = base + chrono::Duration::days(days as i64);
1826 let time = Self::intervals_to_time(intervals, scale);
1827 SqlValue::DateTime(date.and_time(time))
1828 }
1829 #[cfg(not(feature = "chrono"))]
1830 {
1831 SqlValue::String(format!("DATETIME2({days},{intervals})"))
1832 }
1833 }
1834 }
1835
1836 TypeId::DateTimeOffset => {
1838 if buf.remaining() < 1 {
1839 return Err(Error::Protocol(
1840 "unexpected EOF reading DATETIMEOFFSET length".into(),
1841 ));
1842 }
1843 let len = buf.get_u8() as usize;
1844 if len == 0 {
1845 SqlValue::Null
1846 } else if buf.remaining() < len {
1847 return Err(Error::Protocol(
1848 "unexpected EOF reading DATETIMEOFFSET".into(),
1849 ));
1850 } else {
1851 let scale = col.type_info.scale.unwrap_or(7);
1852 let time_len = Self::time_bytes_for_scale(scale);
1853
1854 let mut time_bytes = [0u8; 8];
1856 for byte in time_bytes.iter_mut().take(time_len) {
1857 *byte = buf.get_u8();
1858 }
1859 let intervals = u64::from_le_bytes(time_bytes);
1860
1861 let days = buf.get_u8() as u32
1863 | ((buf.get_u8() as u32) << 8)
1864 | ((buf.get_u8() as u32) << 16);
1865
1866 let offset_minutes = buf.get_i16_le();
1868
1869 #[cfg(feature = "chrono")]
1870 {
1871 use chrono::TimeZone;
1872 let base = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap();
1873 let date = base + chrono::Duration::days(days as i64);
1874 let time = Self::intervals_to_time(intervals, scale);
1875 let offset = chrono::FixedOffset::east_opt((offset_minutes as i32) * 60)
1876 .unwrap_or_else(|| chrono::FixedOffset::east_opt(0).unwrap());
1877 let datetime = offset
1878 .from_local_datetime(&date.and_time(time))
1879 .single()
1880 .unwrap_or_else(|| offset.from_utc_datetime(&date.and_time(time)));
1881 SqlValue::DateTimeOffset(datetime)
1882 }
1883 #[cfg(not(feature = "chrono"))]
1884 {
1885 SqlValue::String(format!(
1886 "DATETIMEOFFSET({days},{intervals},{offset_minutes})"
1887 ))
1888 }
1889 }
1890 }
1891
1892 TypeId::BigVarChar | TypeId::BigChar | TypeId::Text => {
1894 if buf.remaining() < 2 {
1896 return Err(Error::Protocol(
1897 "unexpected EOF reading varchar length".into(),
1898 ));
1899 }
1900 let len = buf.get_u16_le();
1901 if len == 0xFFFF {
1902 SqlValue::Null
1903 } else if buf.remaining() < len as usize {
1904 return Err(Error::Protocol(
1905 "unexpected EOF reading varchar data".into(),
1906 ));
1907 } else {
1908 let data = &buf[..len as usize];
1909 let s = String::from_utf8_lossy(data).into_owned();
1910 buf.advance(len as usize);
1911 SqlValue::String(s)
1912 }
1913 }
1914 TypeId::NVarChar | TypeId::NChar | TypeId::NText => {
1915 if buf.remaining() < 2 {
1917 return Err(Error::Protocol(
1918 "unexpected EOF reading nvarchar length".into(),
1919 ));
1920 }
1921 let len = buf.get_u16_le();
1922 if len == 0xFFFF {
1923 SqlValue::Null
1924 } else if buf.remaining() < len as usize {
1925 return Err(Error::Protocol(
1926 "unexpected EOF reading nvarchar data".into(),
1927 ));
1928 } else {
1929 let data = &buf[..len as usize];
1930 let utf16: Vec<u16> = data
1932 .chunks_exact(2)
1933 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
1934 .collect();
1935 let s = String::from_utf16(&utf16)
1936 .map_err(|_| Error::Protocol("invalid UTF-16 in nvarchar".into()))?;
1937 buf.advance(len as usize);
1938 SqlValue::String(s)
1939 }
1940 }
1941
1942 TypeId::BigVarBinary | TypeId::BigBinary | TypeId::Image => {
1944 if buf.remaining() < 2 {
1945 return Err(Error::Protocol(
1946 "unexpected EOF reading varbinary length".into(),
1947 ));
1948 }
1949 let len = buf.get_u16_le();
1950 if len == 0xFFFF {
1951 SqlValue::Null
1952 } else if buf.remaining() < len as usize {
1953 return Err(Error::Protocol(
1954 "unexpected EOF reading varbinary data".into(),
1955 ));
1956 } else {
1957 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
1958 buf.advance(len as usize);
1959 SqlValue::Binary(data)
1960 }
1961 }
1962
1963 TypeId::Guid => {
1965 if buf.remaining() < 1 {
1966 return Err(Error::Protocol("unexpected EOF reading GUID length".into()));
1967 }
1968 let len = buf.get_u8();
1969 if len == 0 {
1970 SqlValue::Null
1971 } else if len != 16 {
1972 return Err(Error::Protocol(format!("invalid GUID length: {len}")));
1973 } else if buf.remaining() < 16 {
1974 return Err(Error::Protocol("unexpected EOF reading GUID".into()));
1975 } else {
1976 let data = bytes::Bytes::copy_from_slice(&buf[..16]);
1978 buf.advance(16);
1979 SqlValue::Binary(data)
1980 }
1981 }
1982
1983 _ => {
1985 if buf.remaining() < 2 {
1987 return Err(Error::Protocol(format!(
1988 "unexpected EOF reading {:?}",
1989 col.type_id
1990 )));
1991 }
1992 let len = buf.get_u16_le();
1993 if len == 0xFFFF {
1994 SqlValue::Null
1995 } else if buf.remaining() < len as usize {
1996 return Err(Error::Protocol(format!(
1997 "unexpected EOF reading {:?} data",
1998 col.type_id
1999 )));
2000 } else {
2001 let data = bytes::Bytes::copy_from_slice(&buf[..len as usize]);
2002 buf.advance(len as usize);
2003 SqlValue::Binary(data)
2004 }
2005 }
2006 };
2007
2008 Ok(value)
2009 }
2010
/// Returns the number of bytes used for the time portion of a value with the given
/// fractional-second `scale`: 3 bytes for scales 0-2, 4 bytes for scales 3-4 and
/// 5 bytes for every larger scale.
fn time_bytes_for_scale(scale: u8) -> usize {
    if scale <= 2 {
        3
    } else if scale <= 4 {
        4
    } else {
        5
    }
}
2020
#[cfg(feature = "chrono")]
/// Converts a count of `10^-scale` second intervals since midnight into a
/// `chrono::NaiveTime`.
///
/// Scales above 7 are treated as 7 (100 ns intervals). A count that does not describe
/// a valid time of day yields midnight. This includes counts whose nanosecond
/// equivalent overflows `u64` or whose second count does not fit in `u32`, which are
/// handled with checked arithmetic instead of panicking or wrapping.
fn intervals_to_time(intervals: u64, scale: u8) -> chrono::NaiveTime {
    const NANOS_PER_SEC: u64 = 1_000_000_000;

    // One interval at scale `s` lasts 10^(9 - s) nanoseconds.
    let nanos_per_interval = 10u64.pow(9 - u32::from(scale.min(7)));

    intervals
        .checked_mul(nanos_per_interval)
        .and_then(|nanos| {
            let secs = u32::try_from(nanos / NANOS_PER_SEC).ok()?;
            // The remainder is below 1e9 and therefore always fits in a u32.
            let nano_part = (nanos % NANOS_PER_SEC) as u32;
            chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nano_part)
        })
        .unwrap_or_else(|| chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap())
}
2051
2052 async fn read_execute_result(&mut self) -> Result<u64> {
2054 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2055
2056 let message = match connection {
2057 ConnectionHandle::Tls(conn) => conn
2058 .read_message()
2059 .await
2060 .map_err(|e| Error::Protocol(e.to_string()))?,
2061 ConnectionHandle::TlsPrelogin(conn) => conn
2062 .read_message()
2063 .await
2064 .map_err(|e| Error::Protocol(e.to_string()))?,
2065 ConnectionHandle::Plain(conn) => conn
2066 .read_message()
2067 .await
2068 .map_err(|e| Error::Protocol(e.to_string()))?,
2069 }
2070 .ok_or(Error::ConnectionClosed)?;
2071
2072 let mut parser = TokenParser::new(message.payload);
2073 let mut rows_affected = 0u64;
2074 let mut current_metadata: Option<ColMetaData> = None;
2075
2076 loop {
2077 let token = parser
2079 .next_token_with_metadata(current_metadata.as_ref())
2080 .map_err(|e| Error::Protocol(e.to_string()))?;
2081
2082 let Some(token) = token else {
2083 break;
2084 };
2085
2086 match token {
2087 Token::ColMetaData(meta) => {
2088 current_metadata = Some(meta);
2090 }
2091 Token::Row(_) | Token::NbcRow(_) => {
2092 }
2095 Token::Done(done) => {
2096 if done.status.error {
2097 return Err(Error::Query("execution failed".to_string()));
2098 }
2099 if done.status.count {
2100 rows_affected += done.row_count;
2102 }
2103 if !done.status.more {
2106 break;
2107 }
2108 }
2109 Token::DoneProc(done) => {
2110 if done.status.count {
2111 rows_affected += done.row_count;
2112 }
2113 }
2114 Token::DoneInProc(done) => {
2115 if done.status.count {
2116 rows_affected += done.row_count;
2117 }
2118 }
2119 Token::Error(err) => {
2120 return Err(Error::Server {
2121 number: err.number,
2122 state: err.state,
2123 class: err.class,
2124 message: err.message.clone(),
2125 server: if err.server.is_empty() {
2126 None
2127 } else {
2128 Some(err.server.clone())
2129 },
2130 procedure: if err.procedure.is_empty() {
2131 None
2132 } else {
2133 Some(err.procedure.clone())
2134 },
2135 line: err.line as u32,
2136 });
2137 }
2138 Token::Info(info) => {
2139 tracing::info!(
2140 number = info.number,
2141 message = %info.message,
2142 "server info message"
2143 );
2144 }
2145 Token::EnvChange(env) => {
2146 Self::process_transaction_env_change(&env, &mut self.transaction_descriptor);
2150 }
2151 _ => {}
2152 }
2153 }
2154
2155 Ok(rows_affected)
2156 }
2157
2158 async fn read_transaction_begin_result(&mut self) -> Result<u64> {
2164 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2165
2166 let message = match connection {
2167 ConnectionHandle::Tls(conn) => conn
2168 .read_message()
2169 .await
2170 .map_err(|e| Error::Protocol(e.to_string()))?,
2171 ConnectionHandle::TlsPrelogin(conn) => conn
2172 .read_message()
2173 .await
2174 .map_err(|e| Error::Protocol(e.to_string()))?,
2175 ConnectionHandle::Plain(conn) => conn
2176 .read_message()
2177 .await
2178 .map_err(|e| Error::Protocol(e.to_string()))?,
2179 }
2180 .ok_or(Error::ConnectionClosed)?;
2181
2182 let mut parser = TokenParser::new(message.payload);
2183 let mut transaction_descriptor: u64 = 0;
2184
2185 loop {
2186 let token = parser
2187 .next_token()
2188 .map_err(|e| Error::Protocol(e.to_string()))?;
2189
2190 let Some(token) = token else {
2191 break;
2192 };
2193
2194 match token {
2195 Token::EnvChange(env) => {
2196 if env.env_type == EnvChangeType::BeginTransaction {
2197 if let tds_protocol::token::EnvChangeValue::Binary(ref data) = env.new_value
2200 {
2201 if data.len() >= 8 {
2202 transaction_descriptor = u64::from_le_bytes([
2203 data[0], data[1], data[2], data[3], data[4], data[5], data[6],
2204 data[7],
2205 ]);
2206 tracing::debug!(
2207 transaction_descriptor =
2208 format!("0x{:016X}", transaction_descriptor),
2209 "transaction begun"
2210 );
2211 }
2212 }
2213 }
2214 }
2215 Token::Done(done) => {
2216 if done.status.error {
2217 return Err(Error::Query("BEGIN TRANSACTION failed".to_string()));
2218 }
2219 break;
2220 }
2221 Token::Error(err) => {
2222 return Err(Error::Server {
2223 number: err.number,
2224 state: err.state,
2225 class: err.class,
2226 message: err.message.clone(),
2227 server: if err.server.is_empty() {
2228 None
2229 } else {
2230 Some(err.server.clone())
2231 },
2232 procedure: if err.procedure.is_empty() {
2233 None
2234 } else {
2235 Some(err.procedure.clone())
2236 },
2237 line: err.line as u32,
2238 });
2239 }
2240 Token::Info(info) => {
2241 tracing::info!(
2242 number = info.number,
2243 message = %info.message,
2244 "server info message"
2245 );
2246 }
2247 _ => {}
2248 }
2249 }
2250
2251 Ok(transaction_descriptor)
2252 }
2253}
2254
2255impl Client<Ready> {
2256 pub async fn query<'a>(
2281 &'a mut self,
2282 sql: &str,
2283 params: &[&(dyn crate::ToSql + Sync)],
2284 ) -> Result<QueryStream<'a>> {
2285 tracing::debug!(sql = sql, params_count = params.len(), "executing query");
2286
2287 #[cfg(feature = "otel")]
2288 let instrumentation = self.instrumentation.clone();
2289 #[cfg(feature = "otel")]
2290 let mut span = instrumentation.query_span(sql);
2291
2292 let result = async {
2293 if params.is_empty() {
2294 self.send_sql_batch(sql).await?;
2296 } else {
2297 let rpc_params = Self::convert_params(params)?;
2299 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2300 self.send_rpc(&rpc).await?;
2301 }
2302
2303 self.read_query_response().await
2305 }
2306 .await;
2307
2308 #[cfg(feature = "otel")]
2309 match &result {
2310 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2311 Err(e) => InstrumentationContext::record_error(&mut span, e),
2312 }
2313
2314 #[cfg(feature = "otel")]
2316 drop(span);
2317
2318 let (columns, rows) = result?;
2319 Ok(QueryStream::new(columns, rows))
2320 }
2321
2322 pub async fn query_with_timeout<'a>(
2349 &'a mut self,
2350 sql: &str,
2351 params: &[&(dyn crate::ToSql + Sync)],
2352 timeout_duration: std::time::Duration,
2353 ) -> Result<QueryStream<'a>> {
2354 timeout(timeout_duration, self.query(sql, params))
2355 .await
2356 .map_err(|_| Error::CommandTimeout)?
2357 }
2358
2359 pub async fn query_multiple<'a>(
2386 &'a mut self,
2387 sql: &str,
2388 params: &[&(dyn crate::ToSql + Sync)],
2389 ) -> Result<MultiResultStream<'a>> {
2390 tracing::debug!(
2391 sql = sql,
2392 params_count = params.len(),
2393 "executing multi-result query"
2394 );
2395
2396 if params.is_empty() {
2397 self.send_sql_batch(sql).await?;
2399 } else {
2400 let rpc_params = Self::convert_params(params)?;
2402 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2403 self.send_rpc(&rpc).await?;
2404 }
2405
2406 let result_sets = self.read_multi_result_response().await?;
2408 Ok(MultiResultStream::new(result_sets))
2409 }
2410
2411 async fn read_multi_result_response(&mut self) -> Result<Vec<crate::stream::ResultSet>> {
2413 let connection = self.connection.as_mut().ok_or(Error::ConnectionClosed)?;
2414
2415 let message = match connection {
2416 ConnectionHandle::Tls(conn) => conn
2417 .read_message()
2418 .await
2419 .map_err(|e| Error::Protocol(e.to_string()))?,
2420 ConnectionHandle::TlsPrelogin(conn) => conn
2421 .read_message()
2422 .await
2423 .map_err(|e| Error::Protocol(e.to_string()))?,
2424 ConnectionHandle::Plain(conn) => conn
2425 .read_message()
2426 .await
2427 .map_err(|e| Error::Protocol(e.to_string()))?,
2428 }
2429 .ok_or(Error::ConnectionClosed)?;
2430
2431 let mut parser = TokenParser::new(message.payload);
2432 let mut result_sets: Vec<crate::stream::ResultSet> = Vec::new();
2433 let mut current_columns: Vec<crate::row::Column> = Vec::new();
2434 let mut current_rows: Vec<crate::row::Row> = Vec::new();
2435 let mut protocol_metadata: Option<ColMetaData> = None;
2436
2437 loop {
2438 let token = parser
2439 .next_token_with_metadata(protocol_metadata.as_ref())
2440 .map_err(|e| Error::Protocol(e.to_string()))?;
2441
2442 let Some(token) = token else {
2443 break;
2444 };
2445
2446 match token {
2447 Token::ColMetaData(meta) => {
2448 if !current_columns.is_empty() {
2450 result_sets.push(crate::stream::ResultSet::new(
2451 std::mem::take(&mut current_columns),
2452 std::mem::take(&mut current_rows),
2453 ));
2454 }
2455
2456 current_columns = meta
2458 .columns
2459 .iter()
2460 .enumerate()
2461 .map(|(i, col)| {
2462 let type_name = format!("{:?}", col.type_id);
2463 let mut column = crate::row::Column::new(&col.name, i, type_name)
2464 .with_nullable(col.flags & 0x01 != 0);
2465
2466 if let Some(max_len) = col.type_info.max_length {
2467 column = column.with_max_length(max_len);
2468 }
2469 if let (Some(prec), Some(scale)) =
2470 (col.type_info.precision, col.type_info.scale)
2471 {
2472 column = column.with_precision_scale(prec, scale);
2473 }
2474 column
2475 })
2476 .collect();
2477
2478 tracing::debug!(
2479 columns = current_columns.len(),
2480 result_set = result_sets.len(),
2481 "received column metadata for result set"
2482 );
2483 protocol_metadata = Some(meta);
2484 }
2485 Token::Row(raw_row) => {
2486 if let Some(ref meta) = protocol_metadata {
2487 let row = Self::convert_raw_row(&raw_row, meta, ¤t_columns)?;
2488 current_rows.push(row);
2489 }
2490 }
2491 Token::NbcRow(nbc_row) => {
2492 if let Some(ref meta) = protocol_metadata {
2493 let row = Self::convert_nbc_row(&nbc_row, meta, ¤t_columns)?;
2494 current_rows.push(row);
2495 }
2496 }
2497 Token::Error(err) => {
2498 return Err(Error::Server {
2499 number: err.number,
2500 state: err.state,
2501 class: err.class,
2502 message: err.message.clone(),
2503 server: if err.server.is_empty() {
2504 None
2505 } else {
2506 Some(err.server.clone())
2507 },
2508 procedure: if err.procedure.is_empty() {
2509 None
2510 } else {
2511 Some(err.procedure.clone())
2512 },
2513 line: err.line as u32,
2514 });
2515 }
2516 Token::Done(done) => {
2517 if done.status.error {
2518 return Err(Error::Query("query failed".to_string()));
2519 }
2520
2521 if !current_columns.is_empty() {
2523 result_sets.push(crate::stream::ResultSet::new(
2524 std::mem::take(&mut current_columns),
2525 std::mem::take(&mut current_rows),
2526 ));
2527 protocol_metadata = None;
2528 }
2529
2530 if !done.status.more {
2532 tracing::debug!(result_sets = result_sets.len(), "all result sets parsed");
2533 break;
2534 }
2535 }
2536 Token::DoneInProc(done) => {
2537 if done.status.error {
2538 return Err(Error::Query("query failed".to_string()));
2539 }
2540
2541 if !current_columns.is_empty() {
2543 result_sets.push(crate::stream::ResultSet::new(
2544 std::mem::take(&mut current_columns),
2545 std::mem::take(&mut current_rows),
2546 ));
2547 protocol_metadata = None;
2548 }
2549
2550 if !done.status.more {
2552 }
2554 }
2555 Token::DoneProc(done) => {
2556 if done.status.error {
2557 return Err(Error::Query("query failed".to_string()));
2558 }
2559 }
2561 Token::Info(info) => {
2562 tracing::debug!(
2563 number = info.number,
2564 message = %info.message,
2565 "server info message"
2566 );
2567 }
2568 _ => {}
2569 }
2570 }
2571
2572 if !current_columns.is_empty() {
2574 result_sets.push(crate::stream::ResultSet::new(current_columns, current_rows));
2575 }
2576
2577 Ok(result_sets)
2578 }
2579
2580 pub async fn execute(
2584 &mut self,
2585 sql: &str,
2586 params: &[&(dyn crate::ToSql + Sync)],
2587 ) -> Result<u64> {
2588 tracing::debug!(
2589 sql = sql,
2590 params_count = params.len(),
2591 "executing statement"
2592 );
2593
2594 #[cfg(feature = "otel")]
2595 let instrumentation = self.instrumentation.clone();
2596 #[cfg(feature = "otel")]
2597 let mut span = instrumentation.query_span(sql);
2598
2599 let result = async {
2600 if params.is_empty() {
2601 self.send_sql_batch(sql).await?;
2603 } else {
2604 let rpc_params = Self::convert_params(params)?;
2606 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2607 self.send_rpc(&rpc).await?;
2608 }
2609
2610 self.read_execute_result().await
2612 }
2613 .await;
2614
2615 #[cfg(feature = "otel")]
2616 match &result {
2617 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
2618 Err(e) => InstrumentationContext::record_error(&mut span, e),
2619 }
2620
2621 #[cfg(feature = "otel")]
2623 drop(span);
2624
2625 result
2626 }
2627
2628 pub async fn execute_with_timeout(
2655 &mut self,
2656 sql: &str,
2657 params: &[&(dyn crate::ToSql + Sync)],
2658 timeout_duration: std::time::Duration,
2659 ) -> Result<u64> {
2660 timeout(timeout_duration, self.execute(sql, params))
2661 .await
2662 .map_err(|_| Error::CommandTimeout)?
2663 }
2664
2665 pub async fn begin_transaction(mut self) -> Result<Client<InTransaction>> {
2672 tracing::debug!("beginning transaction");
2673
2674 #[cfg(feature = "otel")]
2675 let instrumentation = self.instrumentation.clone();
2676 #[cfg(feature = "otel")]
2677 let mut span = instrumentation.transaction_span("BEGIN");
2678
2679 let result = async {
2681 self.send_sql_batch("BEGIN TRANSACTION").await?;
2682 self.read_transaction_begin_result().await
2683 }
2684 .await;
2685
2686 #[cfg(feature = "otel")]
2687 match &result {
2688 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2689 Err(e) => InstrumentationContext::record_error(&mut span, e),
2690 }
2691
2692 #[cfg(feature = "otel")]
2694 drop(span);
2695
2696 let transaction_descriptor = result?;
2697
2698 Ok(Client {
2699 config: self.config,
2700 _state: PhantomData,
2701 connection: self.connection,
2702 server_version: self.server_version,
2703 current_database: self.current_database,
2704 statement_cache: self.statement_cache,
2705 transaction_descriptor, #[cfg(feature = "otel")]
2707 instrumentation: self.instrumentation,
2708 })
2709 }
2710
2711 pub async fn begin_transaction_with_isolation(
2726 mut self,
2727 isolation_level: crate::transaction::IsolationLevel,
2728 ) -> Result<Client<InTransaction>> {
2729 tracing::debug!(
2730 isolation_level = %isolation_level.name(),
2731 "beginning transaction with isolation level"
2732 );
2733
2734 #[cfg(feature = "otel")]
2735 let instrumentation = self.instrumentation.clone();
2736 #[cfg(feature = "otel")]
2737 let mut span = instrumentation.transaction_span("BEGIN");
2738
2739 let result = async {
2741 self.send_sql_batch(isolation_level.as_sql()).await?;
2742 self.read_execute_result().await?;
2743
2744 self.send_sql_batch("BEGIN TRANSACTION").await?;
2746 self.read_transaction_begin_result().await
2747 }
2748 .await;
2749
2750 #[cfg(feature = "otel")]
2751 match &result {
2752 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2753 Err(e) => InstrumentationContext::record_error(&mut span, e),
2754 }
2755
2756 #[cfg(feature = "otel")]
2757 drop(span);
2758
2759 let transaction_descriptor = result?;
2760
2761 Ok(Client {
2762 config: self.config,
2763 _state: PhantomData,
2764 connection: self.connection,
2765 server_version: self.server_version,
2766 current_database: self.current_database,
2767 statement_cache: self.statement_cache,
2768 transaction_descriptor,
2769 #[cfg(feature = "otel")]
2770 instrumentation: self.instrumentation,
2771 })
2772 }
2773
2774 pub async fn simple_query(&mut self, sql: &str) -> Result<()> {
2779 tracing::debug!(sql = sql, "executing simple query");
2780
2781 self.send_sql_batch(sql).await?;
2783
2784 let _ = self.read_execute_result().await?;
2786
2787 Ok(())
2788 }
2789
2790 pub async fn close(self) -> Result<()> {
2792 tracing::debug!("closing connection");
2793 Ok(())
2794 }
2795
2796 #[must_use]
2798 pub fn database(&self) -> Option<&str> {
2799 self.config.database.as_deref()
2800 }
2801
2802 #[must_use]
2804 pub fn host(&self) -> &str {
2805 &self.config.host
2806 }
2807
2808 #[must_use]
2810 pub fn port(&self) -> u16 {
2811 self.config.port
2812 }
2813
2814 #[must_use]
2833 pub fn is_in_transaction(&self) -> bool {
2834 self.transaction_descriptor != 0
2835 }
2836
2837 #[must_use]
2859 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
2860 let connection = self
2861 .connection
2862 .as_ref()
2863 .expect("connection should be present");
2864 match connection {
2865 ConnectionHandle::Tls(conn) => {
2866 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
2867 }
2868 ConnectionHandle::TlsPrelogin(conn) => {
2869 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
2870 }
2871 ConnectionHandle::Plain(conn) => {
2872 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
2873 }
2874 }
2875 }
2876}
2877
2878impl Client<InTransaction> {
2879 pub async fn query<'a>(
2883 &'a mut self,
2884 sql: &str,
2885 params: &[&(dyn crate::ToSql + Sync)],
2886 ) -> Result<QueryStream<'a>> {
2887 tracing::debug!(
2888 sql = sql,
2889 params_count = params.len(),
2890 "executing query in transaction"
2891 );
2892
2893 #[cfg(feature = "otel")]
2894 let instrumentation = self.instrumentation.clone();
2895 #[cfg(feature = "otel")]
2896 let mut span = instrumentation.query_span(sql);
2897
2898 let result = async {
2899 if params.is_empty() {
2900 self.send_sql_batch(sql).await?;
2902 } else {
2903 let rpc_params = Self::convert_params(params)?;
2905 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2906 self.send_rpc(&rpc).await?;
2907 }
2908
2909 self.read_query_response().await
2911 }
2912 .await;
2913
2914 #[cfg(feature = "otel")]
2915 match &result {
2916 Ok(_) => InstrumentationContext::record_success(&mut span, None),
2917 Err(e) => InstrumentationContext::record_error(&mut span, e),
2918 }
2919
2920 #[cfg(feature = "otel")]
2922 drop(span);
2923
2924 let (columns, rows) = result?;
2925 Ok(QueryStream::new(columns, rows))
2926 }
2927
2928 pub async fn execute(
2932 &mut self,
2933 sql: &str,
2934 params: &[&(dyn crate::ToSql + Sync)],
2935 ) -> Result<u64> {
2936 tracing::debug!(
2937 sql = sql,
2938 params_count = params.len(),
2939 "executing statement in transaction"
2940 );
2941
2942 #[cfg(feature = "otel")]
2943 let instrumentation = self.instrumentation.clone();
2944 #[cfg(feature = "otel")]
2945 let mut span = instrumentation.query_span(sql);
2946
2947 let result = async {
2948 if params.is_empty() {
2949 self.send_sql_batch(sql).await?;
2951 } else {
2952 let rpc_params = Self::convert_params(params)?;
2954 let rpc = RpcRequest::execute_sql(sql, rpc_params);
2955 self.send_rpc(&rpc).await?;
2956 }
2957
2958 self.read_execute_result().await
2960 }
2961 .await;
2962
2963 #[cfg(feature = "otel")]
2964 match &result {
2965 Ok(rows) => InstrumentationContext::record_success(&mut span, Some(*rows)),
2966 Err(e) => InstrumentationContext::record_error(&mut span, e),
2967 }
2968
2969 #[cfg(feature = "otel")]
2971 drop(span);
2972
2973 result
2974 }
2975
2976 pub async fn query_with_timeout<'a>(
2980 &'a mut self,
2981 sql: &str,
2982 params: &[&(dyn crate::ToSql + Sync)],
2983 timeout_duration: std::time::Duration,
2984 ) -> Result<QueryStream<'a>> {
2985 timeout(timeout_duration, self.query(sql, params))
2986 .await
2987 .map_err(|_| Error::CommandTimeout)?
2988 }
2989
2990 pub async fn execute_with_timeout(
2994 &mut self,
2995 sql: &str,
2996 params: &[&(dyn crate::ToSql + Sync)],
2997 timeout_duration: std::time::Duration,
2998 ) -> Result<u64> {
2999 timeout(timeout_duration, self.execute(sql, params))
3000 .await
3001 .map_err(|_| Error::CommandTimeout)?
3002 }
3003
3004 pub async fn commit(mut self) -> Result<Client<Ready>> {
3008 tracing::debug!("committing transaction");
3009
3010 #[cfg(feature = "otel")]
3011 let instrumentation = self.instrumentation.clone();
3012 #[cfg(feature = "otel")]
3013 let mut span = instrumentation.transaction_span("COMMIT");
3014
3015 let result = async {
3017 self.send_sql_batch("COMMIT TRANSACTION").await?;
3018 self.read_execute_result().await
3019 }
3020 .await;
3021
3022 #[cfg(feature = "otel")]
3023 match &result {
3024 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3025 Err(e) => InstrumentationContext::record_error(&mut span, e),
3026 }
3027
3028 #[cfg(feature = "otel")]
3030 drop(span);
3031
3032 result?;
3033
3034 Ok(Client {
3035 config: self.config,
3036 _state: PhantomData,
3037 connection: self.connection,
3038 server_version: self.server_version,
3039 current_database: self.current_database,
3040 statement_cache: self.statement_cache,
3041 transaction_descriptor: 0, #[cfg(feature = "otel")]
3043 instrumentation: self.instrumentation,
3044 })
3045 }
3046
3047 pub async fn rollback(mut self) -> Result<Client<Ready>> {
3051 tracing::debug!("rolling back transaction");
3052
3053 #[cfg(feature = "otel")]
3054 let instrumentation = self.instrumentation.clone();
3055 #[cfg(feature = "otel")]
3056 let mut span = instrumentation.transaction_span("ROLLBACK");
3057
3058 let result = async {
3060 self.send_sql_batch("ROLLBACK TRANSACTION").await?;
3061 self.read_execute_result().await
3062 }
3063 .await;
3064
3065 #[cfg(feature = "otel")]
3066 match &result {
3067 Ok(_) => InstrumentationContext::record_success(&mut span, None),
3068 Err(e) => InstrumentationContext::record_error(&mut span, e),
3069 }
3070
3071 #[cfg(feature = "otel")]
3073 drop(span);
3074
3075 result?;
3076
3077 Ok(Client {
3078 config: self.config,
3079 _state: PhantomData,
3080 connection: self.connection,
3081 server_version: self.server_version,
3082 current_database: self.current_database,
3083 statement_cache: self.statement_cache,
3084 transaction_descriptor: 0, #[cfg(feature = "otel")]
3086 instrumentation: self.instrumentation,
3087 })
3088 }
3089
3090 pub async fn save_point(&mut self, name: &str) -> Result<SavePoint> {
3107 validate_identifier(name)?;
3108 tracing::debug!(name = name, "creating savepoint");
3109
3110 let sql = format!("SAVE TRANSACTION {}", name);
3113 self.send_sql_batch(&sql).await?;
3114 self.read_execute_result().await?;
3115
3116 Ok(SavePoint::new(name.to_string()))
3117 }
3118
3119 pub async fn rollback_to(&mut self, savepoint: &SavePoint) -> Result<()> {
3134 tracing::debug!(name = savepoint.name(), "rolling back to savepoint");
3135
3136 let sql = format!("ROLLBACK TRANSACTION {}", savepoint.name());
3139 self.send_sql_batch(&sql).await?;
3140 self.read_execute_result().await?;
3141
3142 Ok(())
3143 }
3144
3145 pub async fn release_savepoint(&mut self, savepoint: SavePoint) -> Result<()> {
3151 tracing::debug!(name = savepoint.name(), "releasing savepoint");
3152
3153 drop(savepoint);
3157 Ok(())
3158 }
3159
3160 #[must_use]
3164 pub fn cancel_handle(&self) -> crate::cancel::CancelHandle {
3165 let connection = self
3166 .connection
3167 .as_ref()
3168 .expect("connection should be present");
3169 match connection {
3170 ConnectionHandle::Tls(conn) => {
3171 crate::cancel::CancelHandle::from_tls(conn.cancel_handle())
3172 }
3173 ConnectionHandle::TlsPrelogin(conn) => {
3174 crate::cancel::CancelHandle::from_tls_prelogin(conn.cancel_handle())
3175 }
3176 ConnectionHandle::Plain(conn) => {
3177 crate::cancel::CancelHandle::from_plain(conn.cancel_handle())
3178 }
3179 }
3180 }
3181}
3182
3183fn validate_identifier(name: &str) -> Result<()> {
3185 use once_cell::sync::Lazy;
3186 use regex::Regex;
3187
3188 static IDENTIFIER_RE: Lazy<Regex> =
3189 Lazy::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_@#$]{0,127}$").unwrap());
3190
3191 if name.is_empty() {
3192 return Err(Error::InvalidIdentifier(
3193 "identifier cannot be empty".into(),
3194 ));
3195 }
3196
3197 if !IDENTIFIER_RE.is_match(name) {
3198 return Err(Error::InvalidIdentifier(format!(
3199 "invalid identifier '{}': must start with letter/underscore, \
3200 contain only alphanumerics/_/@/#/$, and be 1-128 characters",
3201 name
3202 )));
3203 }
3204
3205 Ok(())
3206}
3207
3208impl<S: ConnectionState> std::fmt::Debug for Client<S> {
3209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3210 f.debug_struct("Client")
3211 .field("host", &self.config.host)
3212 .field("port", &self.config.port)
3213 .field("database", &self.config.database)
3214 .finish()
3215 }
3216}
3217
// Unit tests for `validate_identifier`: accepted names, rejected names, the
// extra characters allowed after the first position, and the length limit.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_identifier_valid() {
        assert!(validate_identifier("my_table").is_ok());
        assert!(validate_identifier("Table123").is_ok());
        assert!(validate_identifier("_private").is_ok());
        assert!(validate_identifier("sp_test").is_ok());
    }

    #[test]
    fn test_validate_identifier_invalid() {
        assert!(validate_identifier("").is_err());
        assert!(validate_identifier("123abc").is_err());
        assert!(validate_identifier("table-name").is_err());
        assert!(validate_identifier("table name").is_err());
        assert!(validate_identifier("table;DROP TABLE users").is_err());
    }

    #[test]
    fn test_validate_identifier_special_characters() {
        // `@`, `#` and `$` are allowed, but only after the first character.
        assert!(validate_identifier("tx@1").is_ok());
        assert!(validate_identifier("sp#1$").is_ok());
        assert!(validate_identifier("_").is_ok());

        assert!(validate_identifier("@start").is_err());
        assert!(validate_identifier("#temp").is_err());
        assert!(validate_identifier("$x").is_err());
    }

    #[test]
    fn test_validate_identifier_length_boundary() {
        // 128 characters is the longest accepted name; 129 is rejected.
        let at_limit = format!("a{}", "b".repeat(127));
        let over_limit = format!("a{}", "b".repeat(128));

        assert_eq!(at_limit.len(), 128);
        assert!(validate_identifier(&at_limit).is_ok());
        assert!(validate_identifier(&over_limit).is_err());
    }

    #[test]
    fn test_validate_identifier_rejects_non_ascii_and_newline() {
        assert!(validate_identifier("tabl\u{e9}").is_err());
        assert!(validate_identifier("name\n").is_err());
        assert!(validate_identifier("name'--").is_err());
    }
}
3239}