1use super::state::ConnectionState;
4use super::tls::SslMode;
5use super::transport::Transport;
6use crate::auth::scram::ChannelBinding;
7use crate::auth::ScramClient;
8use crate::protocol::{
9 decode_message, encode_message, AuthenticationMessage, BackendMessage, FrontendMessage,
10};
11use crate::{Error, Result};
12use bytes::{Buf, BytesMut};
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::time::Duration;
16use tracing::Instrument;
17
/// Process-wide counter of result chunks flushed by `streaming_query`.
///
/// Incremented once per flushed chunk. The returned index is only used to
/// sample chunk metrics (every 10th chunk is recorded). `Relaxed` ordering is
/// used because the counter does not guard any other shared data.
static CHUNK_COUNT: AtomicU64 = AtomicU64::new(0);
21
22#[derive(Debug, Clone)]
27pub struct ConnectionConfig {
28 pub database: String,
30 pub user: String,
32 pub password: Option<String>,
34 pub params: HashMap<String, String>,
36 pub connect_timeout: Option<Duration>,
38 pub statement_timeout: Option<Duration>,
40 pub keepalive_idle: Option<Duration>,
42 pub application_name: Option<String>,
44 pub extra_float_digits: Option<i32>,
46 pub sslmode: SslMode,
48}
49
50impl ConnectionConfig {
51 pub fn new(database: impl Into<String>, user: impl Into<String>) -> Self {
68 Self {
69 database: database.into(),
70 user: user.into(),
71 password: None,
72 params: HashMap::new(),
73 connect_timeout: None,
74 statement_timeout: None,
75 keepalive_idle: None,
76 application_name: None,
77 extra_float_digits: None,
78 sslmode: SslMode::default(),
79 }
80 }
81
82 pub fn builder(
95 database: impl Into<String>,
96 user: impl Into<String>,
97 ) -> ConnectionConfigBuilder {
98 ConnectionConfigBuilder {
99 database: database.into(),
100 user: user.into(),
101 password: None,
102 params: HashMap::new(),
103 connect_timeout: None,
104 statement_timeout: None,
105 keepalive_idle: None,
106 application_name: None,
107 extra_float_digits: None,
108 sslmode: SslMode::default(),
109 }
110 }
111
112 pub fn password(mut self, password: impl Into<String>) -> Self {
114 self.password = Some(password.into());
115 self
116 }
117
118 pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
120 self.params.insert(key.into(), value.into());
121 self
122 }
123}
124
125#[derive(Debug, Clone)]
141pub struct ConnectionConfigBuilder {
142 database: String,
143 user: String,
144 password: Option<String>,
145 params: HashMap<String, String>,
146 connect_timeout: Option<Duration>,
147 statement_timeout: Option<Duration>,
148 keepalive_idle: Option<Duration>,
149 application_name: Option<String>,
150 extra_float_digits: Option<i32>,
151 sslmode: SslMode,
152}
153
154impl ConnectionConfigBuilder {
155 pub fn password(mut self, password: impl Into<String>) -> Self {
157 self.password = Some(password.into());
158 self
159 }
160
161 pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
163 self.params.insert(key.into(), value.into());
164 self
165 }
166
167 pub fn connect_timeout(mut self, duration: Duration) -> Self {
175 self.connect_timeout = Some(duration);
176 self
177 }
178
179 pub fn statement_timeout(mut self, duration: Duration) -> Self {
187 self.statement_timeout = Some(duration);
188 self
189 }
190
191 pub fn keepalive_idle(mut self, duration: Duration) -> Self {
199 self.keepalive_idle = Some(duration);
200 self
201 }
202
203 pub fn application_name(mut self, name: impl Into<String>) -> Self {
211 self.application_name = Some(name.into());
212 self
213 }
214
215 pub fn extra_float_digits(mut self, digits: i32) -> Self {
223 self.extra_float_digits = Some(digits);
224 self
225 }
226
227 pub fn sslmode(mut self, mode: SslMode) -> Self {
229 self.sslmode = mode;
230 self
231 }
232
233 pub fn build(self) -> ConnectionConfig {
235 ConnectionConfig {
236 database: self.database,
237 user: self.user,
238 password: self.password,
239 params: self.params,
240 connect_timeout: self.connect_timeout,
241 statement_timeout: self.statement_timeout,
242 keepalive_idle: self.keepalive_idle,
243 application_name: self.application_name,
244 extra_float_digits: self.extra_float_digits,
245 sslmode: self.sslmode,
246 }
247 }
248}
249
/// A single PostgreSQL frontend/backend protocol connection.
pub struct Connection {
    /// Underlying byte stream. It is `None` while `negotiate_tls` swaps in the
    /// TLS-wrapped transport, and it stays `None` if that upgrade fails.
    transport: Option<Transport>,
    /// Protocol state machine, advanced via `ConnectionState::transition`.
    state: ConnectionState,
    /// Bytes read from the transport that have not been decoded yet.
    read_buf: BytesMut,
    /// Backend process id from `BackendKeyData`, recorded in `authenticate`.
    process_id: Option<i32>,
    /// Secret key from `BackendKeyData`, recorded in `authenticate`.
    secret_key: Option<i32>,
}
258
259impl Connection {
260 pub fn new(transport: Transport) -> Self {
262 Self {
263 transport: Some(transport),
264 state: ConnectionState::Initial,
265 read_buf: BytesMut::with_capacity(8192),
266 process_id: None,
267 secret_key: None,
268 }
269 }
270
271 pub fn state(&self) -> ConnectionState {
273 self.state
274 }
275
276 async fn negotiate_tls(
282 &mut self,
283 tls_config: &super::TlsConfig,
284 hostname: &str,
285 sslmode: SslMode,
286 ) -> Result<()> {
287 self.state.transition(ConnectionState::NegotiatingTls)?;
288
289 let ssl_request = FrontendMessage::SslRequest;
291 self.send_message(&ssl_request).await?;
292
293 let transport = self
295 .transport
296 .as_mut()
297 .expect("transport taken during TLS upgrade");
298 let n = transport.read_buf(&mut self.read_buf).await?;
299 if n == 0 {
300 return Err(Error::ConnectionClosed);
301 }
302
303 let response = self.read_buf[0];
304 self.read_buf.advance(1);
305
306 match response {
307 b'S' => {
308 tracing::debug!("server accepted TLS, upgrading connection");
309 let transport = self.transport.take().expect("transport not available");
311 self.transport = Some(transport.upgrade_to_tls(tls_config, hostname).await?);
312 tracing::info!("TLS connection established");
313 Ok(())
314 }
315 b'N' => {
316 tracing::debug!("server rejected TLS");
317 Err(Error::Config(format!(
318 "server does not support TLS (sslmode={})",
319 sslmode
320 )))
321 }
322 other => Err(Error::Protocol(format!(
323 "unexpected SSLRequest response byte: 0x{:02X}",
324 other
325 ))),
326 }
327 }
328
329 pub async fn startup(
331 &mut self,
332 config: &ConnectionConfig,
333 tls_config: Option<&super::TlsConfig>,
334 hostname: Option<&str>,
335 ) -> Result<()> {
336 async {
337 if config.sslmode != SslMode::Disable {
339 let tls = tls_config.ok_or_else(|| {
340 Error::Config(format!(
341 "sslmode={} requires TlsConfig but none was provided",
342 config.sslmode
343 ))
344 })?;
345 let host = hostname
346 .ok_or_else(|| Error::Config("TLS negotiation requires a hostname".into()))?;
347 self.negotiate_tls(tls, host, config.sslmode).await?;
348 }
349
350 self.state.transition(ConnectionState::AwaitingAuth)?;
351
352 let mut params = vec![
354 ("user".to_string(), config.user.clone()),
355 ("database".to_string(), config.database.clone()),
356 ];
357
358 if let Some(app_name) = &config.application_name {
360 params.push(("application_name".to_string(), app_name.clone()));
361 }
362
363 if let Some(timeout) = config.statement_timeout {
365 params.push((
366 "statement_timeout".to_string(),
367 timeout.as_millis().to_string(),
368 ));
369 }
370
371 if let Some(digits) = config.extra_float_digits {
373 params.push(("extra_float_digits".to_string(), digits.to_string()));
374 }
375
376 for (k, v) in &config.params {
378 params.push((k.clone(), v.clone()));
379 }
380
381 let startup = FrontendMessage::Startup {
383 version: crate::protocol::constants::PROTOCOL_VERSION,
384 params,
385 };
386 self.send_message(&startup).await?;
387
388 self.state.transition(ConnectionState::Authenticating)?;
390 self.authenticate(config).await?;
391
392 self.state.transition(ConnectionState::Idle)?;
393 tracing::info!("startup complete");
394 Ok(())
395 }
396 .instrument(tracing::info_span!(
397 "startup",
398 user = %config.user,
399 database = %config.database
400 ))
401 .await
402 }
403
404 async fn authenticate(&mut self, config: &ConnectionConfig) -> Result<()> {
406 let auth_start = std::time::Instant::now();
407 let mut auth_mechanism = "unknown";
408
409 loop {
410 let msg = self.receive_message().await?;
411
412 match msg {
413 BackendMessage::Authentication(auth) => match auth {
414 AuthenticationMessage::Ok => {
415 tracing::debug!("authentication successful");
416 crate::metrics::counters::auth_successful(auth_mechanism);
417 crate::metrics::histograms::auth_duration(
418 auth_mechanism,
419 auth_start.elapsed().as_millis() as u64,
420 );
421 }
423 AuthenticationMessage::CleartextPassword => {
424 auth_mechanism = crate::metrics::labels::MECHANISM_CLEARTEXT;
425 crate::metrics::counters::auth_attempted(auth_mechanism);
426
427 let password = config
428 .password
429 .as_ref()
430 .ok_or_else(|| Error::Authentication("password required".into()))?;
431 let pwd_msg = FrontendMessage::Password(password.clone());
432 self.send_message(&pwd_msg).await?;
433 }
434 AuthenticationMessage::Md5Password { .. } => {
435 return Err(Error::Authentication(
436 "MD5 authentication not supported. Use SCRAM-SHA-256 or cleartext password".into(),
437 ));
438 }
439 AuthenticationMessage::Sasl { mechanisms } => {
440 auth_mechanism = crate::metrics::labels::MECHANISM_SCRAM;
441 crate::metrics::counters::auth_attempted(auth_mechanism);
442 self.handle_sasl(&mechanisms, config).await?;
443 }
444 AuthenticationMessage::SaslContinue { .. } => {
445 return Err(Error::Protocol(
446 "unexpected SaslContinue outside of SASL flow".into(),
447 ));
448 }
449 AuthenticationMessage::SaslFinal { .. } => {
450 return Err(Error::Protocol(
451 "unexpected SaslFinal outside of SASL flow".into(),
452 ));
453 }
454 },
455 BackendMessage::BackendKeyData {
456 process_id,
457 secret_key,
458 } => {
459 self.process_id = Some(process_id);
460 self.secret_key = Some(secret_key);
461 }
462 BackendMessage::ParameterStatus { name, value } => {
463 tracing::debug!("parameter status: {} = {}", name, value);
464 }
465 BackendMessage::ReadyForQuery { status: _ } => {
466 break;
467 }
468 BackendMessage::ErrorResponse(err) => {
469 crate::metrics::counters::auth_failed(auth_mechanism, "server_error");
470 return Err(Error::Authentication(err.to_string()));
471 }
472 _ => {
473 return Err(Error::Protocol(format!(
474 "unexpected message during auth: {:?}",
475 msg
476 )));
477 }
478 }
479 }
480
481 Ok(())
482 }
483
484 async fn handle_sasl(
486 &mut self,
487 mechanisms: &[String],
488 config: &ConnectionConfig,
489 ) -> Result<()> {
490 let channel_binding_data = self
492 .transport
493 .as_ref()
494 .and_then(|t| t.channel_binding_data());
495
496 let (mechanism, channel_binding) = if mechanisms.contains(&"SCRAM-SHA-256-PLUS".to_string())
497 {
498 if let Some(cb_data) = channel_binding_data {
499 (
500 "SCRAM-SHA-256-PLUS",
501 ChannelBinding::TlsServerEndPoint(cb_data),
502 )
503 } else {
504 ("SCRAM-SHA-256", ChannelBinding::None)
505 }
506 } else if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
507 ("SCRAM-SHA-256", ChannelBinding::None)
508 } else {
509 return Err(Error::Authentication(format!(
510 "server does not support SCRAM-SHA-256. Available: {}",
511 mechanisms.join(", ")
512 )));
513 };
514
515 let password = config.password.as_ref().ok_or_else(|| {
517 Error::Authentication("password required for SCRAM authentication".into())
518 })?;
519
520 let mut scram = ScramClient::with_channel_binding(
522 config.user.clone(),
523 password.clone(),
524 channel_binding,
525 );
526 tracing::debug!("initiating {} authentication", mechanism);
527
528 let client_first = scram.client_first();
530 let msg = FrontendMessage::SaslInitialResponse {
531 mechanism: mechanism.to_string(),
532 data: client_first.into_bytes(),
533 };
534 self.send_message(&msg).await?;
535
536 let server_first_msg = self.receive_message().await?;
538 let server_first_data = match server_first_msg {
539 BackendMessage::Authentication(AuthenticationMessage::SaslContinue { data }) => data,
540 BackendMessage::ErrorResponse(err) => {
541 return Err(Error::Authentication(format!("SASL server error: {}", err)));
542 }
543 _ => {
544 return Err(Error::Protocol(
545 "expected SaslContinue message during SASL authentication".into(),
546 ));
547 }
548 };
549
550 let server_first = String::from_utf8(server_first_data).map_err(|e| {
551 Error::Authentication(format!("invalid UTF-8 in server first message: {}", e))
552 })?;
553
554 tracing::debug!("received SCRAM server first message");
555
556 let (client_final, scram_state) = scram
558 .client_final(&server_first)
559 .map_err(|e| Error::Authentication(format!("SCRAM error: {}", e)))?;
560
561 let msg = FrontendMessage::SaslResponse {
563 data: client_final.into_bytes(),
564 };
565 self.send_message(&msg).await?;
566
567 let server_final_msg = self.receive_message().await?;
569 let server_final_data = match server_final_msg {
570 BackendMessage::Authentication(AuthenticationMessage::SaslFinal { data }) => data,
571 BackendMessage::ErrorResponse(err) => {
572 return Err(Error::Authentication(format!("SASL server error: {}", err)));
573 }
574 _ => {
575 return Err(Error::Protocol(
576 "expected SaslFinal message during SASL authentication".into(),
577 ));
578 }
579 };
580
581 let server_final = String::from_utf8(server_final_data).map_err(|e| {
582 Error::Authentication(format!("invalid UTF-8 in server final message: {}", e))
583 })?;
584
585 scram
587 .verify_server_final(&server_final, &scram_state)
588 .map_err(|e| Error::Authentication(format!("SCRAM verification failed: {}", e)))?;
589
590 tracing::debug!("SCRAM-SHA-256 authentication successful");
591 Ok(())
592 }
593
594 pub async fn simple_query(&mut self, query: &str) -> Result<Vec<BackendMessage>> {
596 if self.state != ConnectionState::Idle {
597 return Err(Error::ConnectionBusy(format!(
598 "connection in state: {}",
599 self.state
600 )));
601 }
602
603 self.state.transition(ConnectionState::QueryInProgress)?;
604
605 let query_msg = FrontendMessage::Query(query.to_string());
606 self.send_message(&query_msg).await?;
607
608 self.state.transition(ConnectionState::ReadingResults)?;
609
610 let mut messages = Vec::new();
611
612 loop {
613 let msg = self.receive_message().await?;
614 let is_ready = matches!(msg, BackendMessage::ReadyForQuery { .. });
615 messages.push(msg);
616
617 if is_ready {
618 break;
619 }
620 }
621
622 self.state.transition(ConnectionState::Idle)?;
623 Ok(messages)
624 }
625
626 async fn send_message(&mut self, msg: &FrontendMessage) -> Result<()> {
628 let buf = encode_message(msg)?;
629 let transport = self.transport.as_mut().expect("transport not available");
630 transport.write_all(&buf).await?;
631 transport.flush().await?;
632 Ok(())
633 }
634
635 async fn receive_message(&mut self) -> Result<BackendMessage> {
637 loop {
638 if let Ok((msg, consumed)) = decode_message(&mut self.read_buf) {
640 self.read_buf.advance(consumed);
641 return Ok(msg);
642 }
643
644 let transport = self.transport.as_mut().expect("transport not available");
646 let n = transport.read_buf(&mut self.read_buf).await?;
647 if n == 0 {
648 return Err(Error::ConnectionClosed);
649 }
650 }
651 }
652
653 pub async fn close(mut self) -> Result<()> {
655 self.state.transition(ConnectionState::Closed)?;
656 let _ = self.send_message(&FrontendMessage::Terminate).await;
657 let transport = self.transport.as_mut().expect("transport not available");
658 transport.shutdown().await?;
659 Ok(())
660 }
661
662 #[allow(clippy::too_many_arguments)]
667 pub async fn streaming_query(
668 mut self,
669 query: &str,
670 chunk_size: usize,
671 max_memory: Option<usize>,
672 soft_limit_warn_threshold: Option<f32>,
673 soft_limit_fail_threshold: Option<f32>,
674 enable_adaptive_chunking: bool,
675 adaptive_min_chunk_size: Option<usize>,
676 adaptive_max_chunk_size: Option<usize>,
677 ) -> Result<crate::stream::JsonStream> {
678 async {
679 let startup_start = std::time::Instant::now();
680
681 use crate::json::validate_row_description;
682 use crate::stream::{extract_json_bytes, parse_json, AdaptiveChunking, ChunkingStrategy, JsonStream};
683 use serde_json::Value;
684 use tokio::sync::mpsc;
685
686 if self.state != ConnectionState::Idle {
687 return Err(Error::ConnectionBusy(format!(
688 "connection in state: {}",
689 self.state
690 )));
691 }
692
693 self.state.transition(ConnectionState::QueryInProgress)?;
694
695 let query_msg = FrontendMessage::Query(query.to_string());
696 self.send_message(&query_msg).await?;
697
698 self.state.transition(ConnectionState::ReadingResults)?;
699
700 let row_desc;
703 loop {
704 let msg = self.receive_message().await?;
705
706 match msg {
707 BackendMessage::ErrorResponse(err) => {
708 tracing::debug!("PostgreSQL error response: {}", err);
710 loop {
711 let msg = self.receive_message().await?;
712 if matches!(msg, BackendMessage::ReadyForQuery { .. }) {
713 break;
714 }
715 }
716 return Err(Error::Sql(err.to_string()));
717 }
718 BackendMessage::BackendKeyData { process_id, secret_key: _ } => {
719 tracing::debug!("PostgreSQL backend key data received: pid={}", process_id);
721 continue;
723 }
724 BackendMessage::ParameterStatus { .. } => {
725 tracing::debug!("PostgreSQL parameter status change received");
727 continue;
728 }
729 BackendMessage::NoticeResponse(notice) => {
730 tracing::debug!("PostgreSQL notice: {}", notice);
732 continue;
733 }
734 BackendMessage::RowDescription(_) => {
735 row_desc = msg;
736 break;
737 }
738 BackendMessage::ReadyForQuery { .. } => {
739 return Err(Error::Protocol(
742 "no result set received from query - \
743 check that the entity name is correct and the table/view exists"
744 .into(),
745 ));
746 }
747 _ => {
748 return Err(Error::Protocol(format!(
749 "unexpected message type in query response: {:?}",
750 msg
751 )));
752 }
753 }
754 }
755
756 validate_row_description(&row_desc)?;
757
758 let startup_duration = startup_start.elapsed().as_millis() as u64;
760 let entity = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
761 crate::metrics::histograms::query_startup_duration(&entity, startup_duration);
762
763 let (result_tx, result_rx) = mpsc::channel::<Result<Value>>(chunk_size);
765 let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
766
767 let entity_for_metrics = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
769 let entity_for_stream = entity_for_metrics.clone(); let stream = JsonStream::new(
772 result_rx,
773 cancel_tx,
774 entity_for_stream,
775 max_memory,
776 soft_limit_warn_threshold,
777 soft_limit_fail_threshold,
778 );
779
780 let state_lock = stream.clone_state();
782 let pause_signal = stream.clone_pause_signal();
783 let resume_signal = stream.clone_resume_signal();
784
785 let state_atomic = stream.clone_state_atomic();
787
788 let pause_timeout = stream.pause_timeout();
790
791 let query_start = std::time::Instant::now();
793
794 tokio::spawn(async move {
795 let strategy = ChunkingStrategy::new(chunk_size);
796 let mut chunk = strategy.new_chunk();
797 let mut total_rows = 0u64;
798
799 let _adaptive = if enable_adaptive_chunking {
801 let mut adp = AdaptiveChunking::new();
802
803 if let Some(min) = adaptive_min_chunk_size {
805 if let Some(max) = adaptive_max_chunk_size {
806 adp = adp.with_bounds(min, max);
807 }
808 }
809
810 Some(adp)
811 } else {
812 None
813 };
814 let _current_chunk_size = chunk_size;
815
816 loop {
817 if state_lock.is_some() && state_atomic.load(std::sync::atomic::Ordering::Acquire) == 1 {
820 if let (Some(ref state_lock), Some(ref _pause_signal), Some(ref resume_signal)) =
822 (&state_lock, &pause_signal, &resume_signal)
823 {
824 let current_state = state_lock.lock().await;
825 if *current_state == crate::stream::StreamState::Paused {
826 tracing::debug!("stream paused, waiting for resume");
827 drop(current_state); if let Some(timeout) = pause_timeout {
831 match tokio::time::timeout(timeout, resume_signal.notified()).await {
832 Ok(_) => {
833 tracing::debug!("stream resumed");
834 }
835 Err(_) => {
836 tracing::debug!("pause timeout expired, auto-resuming");
837 crate::metrics::counters::stream_pause_timeout_expired(&entity_for_metrics);
838 }
839 }
840 } else {
841 resume_signal.notified().await;
843 tracing::debug!("stream resumed");
844 }
845
846 let mut state = state_lock.lock().await;
848 *state = crate::stream::StreamState::Running;
849 }
850 }
851 }
852
853 tokio::select! {
854 _ = cancel_rx.recv() => {
856 tracing::debug!("query cancelled");
857 crate::metrics::counters::query_completed("cancelled", &entity_for_metrics);
858 break;
859 }
860
861 msg_result = self.receive_message() => {
863 match msg_result {
864 Ok(msg) => match msg {
865 BackendMessage::DataRow(_) => {
866 match extract_json_bytes(&msg) {
867 Ok(json_bytes) => {
868 chunk.push(json_bytes);
869
870 if strategy.is_full(&chunk) {
871 let chunk_start = std::time::Instant::now();
872 let rows = chunk.into_rows();
873 let chunk_size_rows = rows.len() as u64;
874
875 const BATCH_SIZE: usize = 8;
878 let mut batch = Vec::with_capacity(BATCH_SIZE);
879 let mut send_error = false;
880
881 for row_bytes in rows {
882 match parse_json(row_bytes) {
883 Ok(value) => {
884 total_rows += 1;
885 batch.push(Ok(value));
886
887 if batch.len() == BATCH_SIZE {
889 for item in batch.drain(..) {
890 if result_tx.send(item).await.is_err() {
891 crate::metrics::counters::query_completed("error", &entity_for_metrics);
892 send_error = true;
893 break;
894 }
895 }
896 if send_error {
897 break;
898 }
899 }
900 }
901 Err(e) => {
902 crate::metrics::counters::json_parse_error(&entity_for_metrics);
903 let _ = result_tx.send(Err(e)).await;
904 crate::metrics::counters::query_completed("error", &entity_for_metrics);
905 send_error = true;
906 break;
907 }
908 }
909 }
910
911 if !send_error {
913 for item in batch {
914 if result_tx.send(item).await.is_err() {
915 crate::metrics::counters::query_completed("error", &entity_for_metrics);
916 break;
917 }
918 }
919 }
920
921 let chunk_duration = chunk_start.elapsed().as_millis() as u64;
923
924 let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
926 if chunk_idx % 10 == 0 {
927 crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
928 crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
929 }
930
931 chunk = strategy.new_chunk();
937 }
938 }
939 Err(e) => {
940 crate::metrics::counters::json_parse_error(&entity_for_metrics);
941 let _ = result_tx.send(Err(e)).await;
942 crate::metrics::counters::query_completed("error", &entity_for_metrics);
943 break;
944 }
945 }
946 }
947 BackendMessage::CommandComplete(_) => {
948 if !chunk.is_empty() {
950 let chunk_start = std::time::Instant::now();
951 let rows = chunk.into_rows();
952 let chunk_size_rows = rows.len() as u64;
953
954 const BATCH_SIZE: usize = 8;
956 let mut batch = Vec::with_capacity(BATCH_SIZE);
957 let mut send_error = false;
958
959 for row_bytes in rows {
960 match parse_json(row_bytes) {
961 Ok(value) => {
962 total_rows += 1;
963 batch.push(Ok(value));
964
965 if batch.len() == BATCH_SIZE {
967 for item in batch.drain(..) {
968 if result_tx.send(item).await.is_err() {
969 crate::metrics::counters::query_completed("error", &entity_for_metrics);
970 send_error = true;
971 break;
972 }
973 }
974 if send_error {
975 break;
976 }
977 }
978 }
979 Err(e) => {
980 crate::metrics::counters::json_parse_error(&entity_for_metrics);
981 let _ = result_tx.send(Err(e)).await;
982 crate::metrics::counters::query_completed("error", &entity_for_metrics);
983 send_error = true;
984 break;
985 }
986 }
987 }
988
989 if !send_error {
991 for item in batch {
992 if result_tx.send(item).await.is_err() {
993 crate::metrics::counters::query_completed("error", &entity_for_metrics);
994 break;
995 }
996 }
997 }
998
999 let chunk_duration = chunk_start.elapsed().as_millis() as u64;
1001 let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
1002 if chunk_idx % 10 == 0 {
1003 crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
1004 crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
1005 }
1006 chunk = strategy.new_chunk();
1007 }
1008
1009 let query_duration = query_start.elapsed().as_millis() as u64;
1011 crate::metrics::counters::rows_processed(&entity_for_metrics, total_rows, "ok");
1012 crate::metrics::histograms::query_total_duration(&entity_for_metrics, query_duration);
1013 crate::metrics::counters::query_completed("success", &entity_for_metrics);
1014 }
1015 BackendMessage::ReadyForQuery { .. } => {
1016 break;
1017 }
1018 BackendMessage::ErrorResponse(err) => {
1019 crate::metrics::counters::query_error(&entity_for_metrics, "server_error");
1020 crate::metrics::counters::query_completed("error", &entity_for_metrics);
1021 let _ = result_tx.send(Err(Error::Sql(err.to_string()))).await;
1022 break;
1023 }
1024 _ => {
1025 crate::metrics::counters::query_error(&entity_for_metrics, "protocol_error");
1026 crate::metrics::counters::query_completed("error", &entity_for_metrics);
1027 let _ = result_tx.send(Err(Error::Protocol(
1028 format!("unexpected message: {:?}", msg)
1029 ))).await;
1030 break;
1031 }
1032 },
1033 Err(e) => {
1034 crate::metrics::counters::query_error(&entity_for_metrics, "connection_error");
1035 crate::metrics::counters::query_completed("error", &entity_for_metrics);
1036 let _ = result_tx.send(Err(e)).await;
1037 break;
1038 }
1039 }
1040 }
1041 }
1042 }
1043 });
1044
1045 Ok(stream)
1046 }
1047 .instrument(tracing::debug_span!(
1048 "streaming_query",
1049 query = %query,
1050 chunk_size = %chunk_size
1051 ))
1052 .await
1053 }
1054}
1055
/// Derives a metrics label ("entity") from a SQL query.
///
/// For each standalone `FROM` keyword (case-insensitive), the relation name
/// that follows is examined from its first `v` (or, failing that, `t`) —
/// matching `v_<entity>` / `tv_<entity>`-style names. The text after the last
/// `_` is returned. Returns `None` if no relation yields a label.
///
/// This is a best-effort heuristic for labelling metrics, not a SQL parser.
fn extract_entity_from_query(query: &str) -> Option<String> {
    let query_lower = query.to_lowercase();
    let mut words = query_lower.split_whitespace();
    // `any` stops right after a matching `from`, so `next()` then yields the
    // relation that follows it.
    while words.any(|word| word == "from") {
        match words.next() {
            Some(relation) => {
                if let Some(entity) = entity_from_relation(relation) {
                    return Some(entity);
                }
            }
            None => break,
        }
    }
    None
}

/// Extracts the entity label from one relation token such as `tv_user`,
/// `v_user;` or `"public"."v_user"`.
///
/// Only the leading run of name characters (alphanumerics, `_`, `.`, `"`) is
/// considered, so trailing `;`, `)` or `,` are ignored.
fn entity_from_relation(relation: &str) -> Option<String> {
    let is_name_char = |c: char| c.is_alphanumeric() || matches!(c, '_' | '.' | '"');
    let end = relation
        .find(|c: char| !is_name_char(c))
        .unwrap_or(relation.len());
    let name = &relation[..end];
    let start = name.find('v').or_else(|| name.find('t'))?;
    let (_, entity) = name[start..].rsplit_once('_')?;
    Some(entity.trim_matches('"').to_string())
}
1078
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_config() {
        let config = ConnectionConfig::new("testdb", "testuser")
            .password("testpass")
            .param("application_name", "fraiseql-wire");

        assert_eq!(config.database, "testdb");
        assert_eq!(config.user, "testuser");
        assert_eq!(config.password, Some("testpass".to_string()));
        assert_eq!(
            config.params.get("application_name"),
            Some(&"fraiseql-wire".to_string())
        );
    }

    #[test]
    fn test_connection_config_builder_basic() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("mypass")
            .build();

        assert_eq!(config.database, "mydb");
        assert_eq!(config.user, "myuser");
        assert_eq!(config.password, Some("mypass".to_string()));
        assert_eq!(config.connect_timeout, None);
        assert_eq!(config.statement_timeout, None);
        assert_eq!(config.keepalive_idle, None);
        assert_eq!(config.application_name, None);
    }

    #[test]
    fn test_connection_config_builder_with_timeouts() {
        let connect_timeout = Duration::from_secs(10);
        let statement_timeout = Duration::from_secs(30);
        let keepalive_idle = Duration::from_secs(300);

        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("mypass")
            .connect_timeout(connect_timeout)
            .statement_timeout(statement_timeout)
            .keepalive_idle(keepalive_idle)
            .build();

        assert_eq!(config.connect_timeout, Some(connect_timeout));
        assert_eq!(config.statement_timeout, Some(statement_timeout));
        assert_eq!(config.keepalive_idle, Some(keepalive_idle));
    }

    #[test]
    fn test_connection_config_builder_with_application_name() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .application_name("my_app")
            .extra_float_digits(2)
            .build();

        assert_eq!(config.application_name, Some("my_app".to_string()));
        assert_eq!(config.extra_float_digits, Some(2));
    }

    #[test]
    fn test_connection_config_builder_fluent() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("secret")
            .param("key1", "value1")
            .connect_timeout(Duration::from_secs(5))
            .statement_timeout(Duration::from_secs(60))
            .application_name("test_app")
            .build();

        assert_eq!(config.database, "mydb");
        assert_eq!(config.user, "myuser");
        assert_eq!(config.password, Some("secret".to_string()));
        assert_eq!(config.params.get("key1"), Some(&"value1".to_string()));
        assert_eq!(config.connect_timeout, Some(Duration::from_secs(5)));
        assert_eq!(config.statement_timeout, Some(Duration::from_secs(60)));
        assert_eq!(config.application_name, Some("test_app".to_string()));
    }

    #[test]
    fn test_connection_config_defaults() {
        let config = ConnectionConfig::new("db", "user");

        assert!(config.connect_timeout.is_none());
        assert!(config.statement_timeout.is_none());
        assert!(config.keepalive_idle.is_none());
        assert!(config.application_name.is_none());
        assert!(config.extra_float_digits.is_none());
        assert_eq!(config.sslmode, super::SslMode::Disable);
    }

    #[test]
    fn test_connection_config_builder_with_sslmode() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .sslmode(super::SslMode::VerifyFull)
            .build();

        assert_eq!(config.sslmode, super::SslMode::VerifyFull);
    }

    /// `new` and `builder(..).build()` must produce identical defaults.
    #[test]
    fn test_new_matches_builder_defaults() {
        let from_new = ConnectionConfig::new("db", "user");
        let from_builder = ConnectionConfig::builder("db", "user").build();

        assert_eq!(from_new.database, from_builder.database);
        assert_eq!(from_new.user, from_builder.user);
        assert_eq!(from_new.password, from_builder.password);
        assert_eq!(from_new.params, from_builder.params);
        assert_eq!(from_new.connect_timeout, from_builder.connect_timeout);
        assert_eq!(from_new.statement_timeout, from_builder.statement_timeout);
        assert_eq!(from_new.keepalive_idle, from_builder.keepalive_idle);
        assert_eq!(from_new.application_name, from_builder.application_name);
        assert_eq!(from_new.extra_float_digits, from_builder.extra_float_digits);
        assert_eq!(from_new.sslmode, from_builder.sslmode);
    }

    #[test]
    fn test_extract_entity_from_query() {
        assert_eq!(
            extract_entity_from_query("SELECT data FROM v_user"),
            Some("user".to_string())
        );
        assert_eq!(
            extract_entity_from_query("select data from tv_post where id = 1"),
            Some("post".to_string())
        );
        assert_eq!(
            extract_entity_from_query("select data from v_user;"),
            Some("user".to_string())
        );
        assert_eq!(extract_entity_from_query("select 1"), None);
    }

    /// `streaming_query` moves the connection into `tokio::spawn`, which
    /// requires `Send`; this makes a regression fail at compile time here.
    /// (It replaces a check on `Pin<Box<dyn Future + Send>>`, a type that is
    /// `Send` by construction and so verified nothing.)
    #[test]
    fn test_connection_is_send() {
        fn require_send<T: Send>() {}
        require_send::<Connection>();
    }
}