1use super::state::ConnectionState;
4use super::transport::Transport;
5use crate::auth::ScramClient;
6use crate::protocol::{
7 decode_message, encode_message, AuthenticationMessage, BackendMessage, FrontendMessage,
8};
9use crate::{Error, Result};
10use bytes::{Buf, BytesMut};
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::time::Duration;
14use tracing::Instrument;
15use zeroize::Zeroizing;
16
/// Process-wide count of row chunks flushed by `Connection::streaming_query`.
///
/// Shared by every connection and query in the process. Chunk-level metrics
/// (`chunk_processing_duration`, `chunk_size`) are only recorded when the value
/// returned by `fetch_add` is a multiple of 10, i.e. for every tenth chunk.
static CHUNK_COUNT: AtomicU64 = AtomicU64::new(0);
20
21#[derive(Debug, Clone)]
26pub struct ConnectionConfig {
27 pub database: String,
29 pub user: String,
31 pub password: Option<Zeroizing<String>>,
33 pub params: HashMap<String, String>,
35 pub connect_timeout: Option<Duration>,
37 pub statement_timeout: Option<Duration>,
39 pub keepalive_idle: Option<Duration>,
41 pub application_name: Option<String>,
43 pub extra_float_digits: Option<i32>,
45}
46
47impl ConnectionConfig {
48 pub fn new(database: impl Into<String>, user: impl Into<String>) -> Self {
65 Self {
66 database: database.into(),
67 user: user.into(),
68 password: None,
69 params: HashMap::new(),
70 connect_timeout: None,
71 statement_timeout: None,
72 keepalive_idle: None,
73 application_name: None,
74 extra_float_digits: None,
75 }
76 }
77
78 pub fn builder(
91 database: impl Into<String>,
92 user: impl Into<String>,
93 ) -> ConnectionConfigBuilder {
94 ConnectionConfigBuilder {
95 database: database.into(),
96 user: user.into(),
97 password: None,
98 params: HashMap::new(),
99 connect_timeout: None,
100 statement_timeout: None,
101 keepalive_idle: None,
102 application_name: None,
103 extra_float_digits: None,
104 }
105 }
106
107 pub fn password(mut self, password: impl Into<String>) -> Self {
109 self.password = Some(Zeroizing::new(password.into()));
110 self
111 }
112
113 pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
115 self.params.insert(key.into(), value.into());
116 self
117 }
118}
119
120#[derive(Debug, Clone)]
136pub struct ConnectionConfigBuilder {
137 database: String,
138 user: String,
139 password: Option<Zeroizing<String>>,
140 params: HashMap<String, String>,
141 connect_timeout: Option<Duration>,
142 statement_timeout: Option<Duration>,
143 keepalive_idle: Option<Duration>,
144 application_name: Option<String>,
145 extra_float_digits: Option<i32>,
146}
147
148impl ConnectionConfigBuilder {
149 pub fn password(mut self, password: impl Into<String>) -> Self {
151 self.password = Some(Zeroizing::new(password.into()));
152 self
153 }
154
155 pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
157 self.params.insert(key.into(), value.into());
158 self
159 }
160
161 pub fn connect_timeout(mut self, duration: Duration) -> Self {
169 self.connect_timeout = Some(duration);
170 self
171 }
172
173 pub fn statement_timeout(mut self, duration: Duration) -> Self {
181 self.statement_timeout = Some(duration);
182 self
183 }
184
185 pub fn keepalive_idle(mut self, duration: Duration) -> Self {
193 self.keepalive_idle = Some(duration);
194 self
195 }
196
197 pub fn application_name(mut self, name: impl Into<String>) -> Self {
205 self.application_name = Some(name.into());
206 self
207 }
208
209 pub fn extra_float_digits(mut self, digits: i32) -> Self {
217 self.extra_float_digits = Some(digits);
218 self
219 }
220
221 pub fn build(self) -> ConnectionConfig {
223 ConnectionConfig {
224 database: self.database,
225 user: self.user,
226 password: self.password,
227 params: self.params,
228 connect_timeout: self.connect_timeout,
229 statement_timeout: self.statement_timeout,
230 keepalive_idle: self.keepalive_idle,
231 application_name: self.application_name,
232 extra_float_digits: self.extra_float_digits,
233 }
234 }
235}
236
/// A single PostgreSQL frontend/backend protocol session over a `Transport`.
///
/// Methods check and advance `state`, and incoming bytes are buffered in
/// `read_buf` until a whole backend message can be decoded.
pub struct Connection {
    /// Underlying byte stream to the server.
    transport: Transport,
    /// Protocol state machine; advanced via `ConnectionState::transition`.
    state: ConnectionState,
    /// Bytes received from the transport but not yet decoded into a message.
    read_buf: BytesMut,
    /// Backend process id from `BackendKeyData`, once received during startup.
    process_id: Option<i32>,
    /// Secret key from `BackendKeyData`, once received during startup.
    /// NOTE(review): not read in the code shown here; presumably kept for
    /// cancel requests.
    secret_key: Option<i32>,
}
245
246impl Connection {
247 pub fn new(transport: Transport) -> Self {
249 Self {
250 transport,
251 state: ConnectionState::Initial,
252 read_buf: BytesMut::with_capacity(8192),
253 process_id: None,
254 secret_key: None,
255 }
256 }
257
258 pub fn state(&self) -> ConnectionState {
260 self.state
261 }
262
263 pub async fn startup(&mut self, config: &ConnectionConfig) -> Result<()> {
265 async {
266 self.state.transition(ConnectionState::AwaitingAuth)?;
267
268 let mut params = vec![
270 ("user".to_string(), config.user.clone()),
271 ("database".to_string(), config.database.clone()),
272 ];
273
274 if let Some(app_name) = &config.application_name {
276 params.push(("application_name".to_string(), app_name.clone()));
277 }
278
279 if let Some(timeout) = config.statement_timeout {
281 params.push((
282 "statement_timeout".to_string(),
283 timeout.as_millis().to_string(),
284 ));
285 }
286
287 if let Some(digits) = config.extra_float_digits {
289 params.push(("extra_float_digits".to_string(), digits.to_string()));
290 }
291
292 for (k, v) in &config.params {
294 params.push((k.clone(), v.clone()));
295 }
296
297 let startup = FrontendMessage::Startup {
299 version: crate::protocol::constants::PROTOCOL_VERSION,
300 params,
301 };
302 self.send_message(&startup).await?;
303
304 self.state.transition(ConnectionState::Authenticating)?;
306 self.authenticate(config).await?;
307
308 self.state.transition(ConnectionState::Idle)?;
309 tracing::info!("startup complete");
310 Ok(())
311 }
312 .instrument(tracing::info_span!(
313 "startup",
314 user = %config.user,
315 database = %config.database
316 ))
317 .await
318 }
319
320 async fn authenticate(&mut self, config: &ConnectionConfig) -> Result<()> {
322 let auth_start = std::time::Instant::now();
323 let mut auth_mechanism = "unknown";
324
325 loop {
326 let msg = self.receive_message().await?;
327
328 match msg {
329 BackendMessage::Authentication(auth) => match auth {
330 AuthenticationMessage::Ok => {
331 tracing::debug!("authentication successful");
332 crate::metrics::counters::auth_successful(auth_mechanism);
333 crate::metrics::histograms::auth_duration(
334 auth_mechanism,
335 auth_start.elapsed().as_millis() as u64,
336 );
337 }
339 AuthenticationMessage::CleartextPassword => {
340 auth_mechanism = crate::metrics::labels::MECHANISM_CLEARTEXT;
341 crate::metrics::counters::auth_attempted(auth_mechanism);
342
343 let password = config
344 .password
345 .as_ref()
346 .ok_or_else(|| Error::Authentication("password required".into()))?;
347 let pwd_msg = FrontendMessage::Password(password.as_str().to_string());
349 self.send_message(&pwd_msg).await?;
350 }
351 AuthenticationMessage::Md5Password { .. } => {
352 return Err(Error::Authentication(
353 "MD5 authentication not supported. Use SCRAM-SHA-256 or cleartext password".into(),
354 ));
355 }
356 AuthenticationMessage::Sasl { mechanisms } => {
357 auth_mechanism = crate::metrics::labels::MECHANISM_SCRAM;
358 crate::metrics::counters::auth_attempted(auth_mechanism);
359 self.handle_sasl(&mechanisms, config).await?;
360 }
361 AuthenticationMessage::SaslContinue { .. } => {
362 return Err(Error::Protocol(
363 "unexpected SaslContinue outside of SASL flow".into(),
364 ));
365 }
366 AuthenticationMessage::SaslFinal { .. } => {
367 return Err(Error::Protocol(
368 "unexpected SaslFinal outside of SASL flow".into(),
369 ));
370 }
371 },
372 BackendMessage::BackendKeyData {
373 process_id,
374 secret_key,
375 } => {
376 self.process_id = Some(process_id);
377 self.secret_key = Some(secret_key);
378 }
379 BackendMessage::ParameterStatus { name, value } => {
380 tracing::debug!("parameter status: {} = {}", name, value);
381 }
382 BackendMessage::ReadyForQuery { status: _ } => {
383 break;
384 }
385 BackendMessage::ErrorResponse(err) => {
386 crate::metrics::counters::auth_failed(auth_mechanism, "server_error");
387 return Err(Error::Authentication(err.to_string()));
388 }
389 _ => {
390 return Err(Error::Protocol(format!(
391 "unexpected message during auth: {:?}",
392 msg
393 )));
394 }
395 }
396 }
397
398 Ok(())
399 }
400
401 async fn handle_sasl(
403 &mut self,
404 mechanisms: &[String],
405 config: &ConnectionConfig,
406 ) -> Result<()> {
407 if !mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
409 return Err(Error::Authentication(format!(
410 "server does not support SCRAM-SHA-256. Available: {}",
411 mechanisms.join(", ")
412 )));
413 }
414
415 let password = config.password.as_ref().ok_or_else(|| {
417 Error::Authentication("password required for SCRAM authentication".into())
418 })?;
419
420 let mut scram = ScramClient::new(config.user.clone(), password.as_str().to_string());
423 tracing::debug!("initiating SCRAM-SHA-256 authentication");
424
425 let client_first = scram.client_first();
427 let msg = FrontendMessage::SaslInitialResponse {
428 mechanism: "SCRAM-SHA-256".to_string(),
429 data: client_first.into_bytes(),
430 };
431 self.send_message(&msg).await?;
432
433 let server_first_msg = self.receive_message().await?;
435 let server_first_data = match server_first_msg {
436 BackendMessage::Authentication(AuthenticationMessage::SaslContinue { data }) => data,
437 BackendMessage::ErrorResponse(err) => {
438 return Err(Error::Authentication(format!("SASL server error: {}", err)));
439 }
440 _ => {
441 return Err(Error::Protocol(
442 "expected SaslContinue message during SASL authentication".into(),
443 ));
444 }
445 };
446
447 let server_first = String::from_utf8(server_first_data).map_err(|e| {
448 Error::Authentication(format!("invalid UTF-8 in server first message: {}", e))
449 })?;
450
451 tracing::debug!("received SCRAM server first message");
452
453 let (client_final, scram_state) = scram
455 .client_final(&server_first)
456 .map_err(|e| Error::Authentication(format!("SCRAM error: {}", e)))?;
457
458 let msg = FrontendMessage::SaslResponse {
460 data: client_final.into_bytes(),
461 };
462 self.send_message(&msg).await?;
463
464 let server_final_msg = self.receive_message().await?;
466 let server_final_data = match server_final_msg {
467 BackendMessage::Authentication(AuthenticationMessage::SaslFinal { data }) => data,
468 BackendMessage::ErrorResponse(err) => {
469 return Err(Error::Authentication(format!("SASL server error: {}", err)));
470 }
471 _ => {
472 return Err(Error::Protocol(
473 "expected SaslFinal message during SASL authentication".into(),
474 ));
475 }
476 };
477
478 let server_final = String::from_utf8(server_final_data).map_err(|e| {
479 Error::Authentication(format!("invalid UTF-8 in server final message: {}", e))
480 })?;
481
482 scram
484 .verify_server_final(&server_final, &scram_state)
485 .map_err(|e| Error::Authentication(format!("SCRAM verification failed: {}", e)))?;
486
487 tracing::debug!("SCRAM-SHA-256 authentication successful");
488 Ok(())
489 }
490
491 pub async fn simple_query(&mut self, query: &str) -> Result<Vec<BackendMessage>> {
493 if self.state != ConnectionState::Idle {
494 return Err(Error::ConnectionBusy(format!(
495 "connection in state: {}",
496 self.state
497 )));
498 }
499
500 self.state.transition(ConnectionState::QueryInProgress)?;
501
502 let query_msg = FrontendMessage::Query(query.to_string());
503 self.send_message(&query_msg).await?;
504
505 self.state.transition(ConnectionState::ReadingResults)?;
506
507 let mut messages = Vec::new();
508
509 loop {
510 let msg = self.receive_message().await?;
511 let is_ready = matches!(msg, BackendMessage::ReadyForQuery { .. });
512 messages.push(msg);
513
514 if is_ready {
515 break;
516 }
517 }
518
519 self.state.transition(ConnectionState::Idle)?;
520 Ok(messages)
521 }
522
523 async fn send_message(&mut self, msg: &FrontendMessage) -> Result<()> {
525 let buf = encode_message(msg)?;
526 self.transport.write_all(&buf).await?;
527 self.transport.flush().await?;
528 Ok(())
529 }
530
531 async fn receive_message(&mut self) -> Result<BackendMessage> {
533 loop {
534 if let Ok((msg, consumed)) = decode_message(&mut self.read_buf) {
536 self.read_buf.advance(consumed);
537 return Ok(msg);
538 }
539
540 let n = self.transport.read_buf(&mut self.read_buf).await?;
542 if n == 0 {
543 return Err(Error::ConnectionClosed);
544 }
545 }
546 }
547
548 pub async fn close(mut self) -> Result<()> {
550 self.state.transition(ConnectionState::Closed)?;
551 let _ = self.send_message(&FrontendMessage::Terminate).await;
552 self.transport.shutdown().await?;
553 Ok(())
554 }
555
556 #[allow(clippy::too_many_arguments)]
561 pub async fn streaming_query(
562 mut self,
563 query: &str,
564 chunk_size: usize,
565 max_memory: Option<usize>,
566 soft_limit_warn_threshold: Option<f32>,
567 soft_limit_fail_threshold: Option<f32>,
568 enable_adaptive_chunking: bool,
569 adaptive_min_chunk_size: Option<usize>,
570 adaptive_max_chunk_size: Option<usize>,
571 ) -> Result<crate::stream::JsonStream> {
572 async {
573 let startup_start = std::time::Instant::now();
574
575 use crate::json::validate_row_description;
576 use crate::stream::{extract_json_bytes, parse_json, AdaptiveChunking, ChunkingStrategy, JsonStream};
577 use serde_json::Value;
578 use tokio::sync::mpsc;
579
580 if self.state != ConnectionState::Idle {
581 return Err(Error::ConnectionBusy(format!(
582 "connection in state: {}",
583 self.state
584 )));
585 }
586
587 self.state.transition(ConnectionState::QueryInProgress)?;
588
589 let query_msg = FrontendMessage::Query(query.to_string());
590 self.send_message(&query_msg).await?;
591
592 self.state.transition(ConnectionState::ReadingResults)?;
593
594 let row_desc;
597 loop {
598 let msg = self.receive_message().await?;
599
600 match msg {
601 BackendMessage::ErrorResponse(err) => {
602 tracing::debug!("PostgreSQL error response: {}", err);
604 loop {
605 let msg = self.receive_message().await?;
606 if matches!(msg, BackendMessage::ReadyForQuery { .. }) {
607 break;
608 }
609 }
610 return Err(Error::Sql(err.to_string()));
611 }
612 BackendMessage::BackendKeyData { process_id, secret_key: _ } => {
613 tracing::debug!("PostgreSQL backend key data received: pid={}", process_id);
615 continue;
617 }
618 BackendMessage::ParameterStatus { .. } => {
619 tracing::debug!("PostgreSQL parameter status change received");
621 continue;
622 }
623 BackendMessage::NoticeResponse(notice) => {
624 tracing::debug!("PostgreSQL notice: {}", notice);
626 continue;
627 }
628 BackendMessage::RowDescription(_) => {
629 row_desc = msg;
630 break;
631 }
632 BackendMessage::ReadyForQuery { .. } => {
633 return Err(Error::Protocol(
636 "no result set received from query - \
637 check that the entity name is correct and the table/view exists"
638 .into(),
639 ));
640 }
641 _ => {
642 return Err(Error::Protocol(format!(
643 "unexpected message type in query response: {:?}",
644 msg
645 )));
646 }
647 }
648 }
649
650 validate_row_description(&row_desc)?;
651
652 let startup_duration = startup_start.elapsed().as_millis() as u64;
654 let entity = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
655 crate::metrics::histograms::query_startup_duration(&entity, startup_duration);
656
657 let (result_tx, result_rx) = mpsc::channel::<Result<Value>>(chunk_size);
659 let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
660
661 let entity_for_metrics = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
663 let entity_for_stream = entity_for_metrics.clone(); let stream = JsonStream::new(
666 result_rx,
667 cancel_tx,
668 entity_for_stream,
669 max_memory,
670 soft_limit_warn_threshold,
671 soft_limit_fail_threshold,
672 );
673
674 let state_lock = stream.clone_state();
676 let pause_signal = stream.clone_pause_signal();
677 let resume_signal = stream.clone_resume_signal();
678
679 let state_atomic = stream.clone_state_atomic();
681
682 let pause_timeout = stream.pause_timeout();
684
685 let query_start = std::time::Instant::now();
687
688 tokio::spawn(async move {
689 let strategy = ChunkingStrategy::new(chunk_size);
690 let mut chunk = strategy.new_chunk();
691 let mut total_rows = 0u64;
692
693 let _adaptive = if enable_adaptive_chunking {
695 let mut adp = AdaptiveChunking::new();
696
697 if let Some(min) = adaptive_min_chunk_size {
699 if let Some(max) = adaptive_max_chunk_size {
700 adp = adp.with_bounds(min, max);
701 }
702 }
703
704 Some(adp)
705 } else {
706 None
707 };
708 let _current_chunk_size = chunk_size;
709
710 loop {
711 if state_lock.is_some() && state_atomic.load(std::sync::atomic::Ordering::Acquire) == 1 {
714 if let (Some(ref state_lock), Some(ref _pause_signal), Some(ref resume_signal)) =
716 (&state_lock, &pause_signal, &resume_signal)
717 {
718 let current_state = state_lock.lock().await;
719 if *current_state == crate::stream::StreamState::Paused {
720 tracing::debug!("stream paused, waiting for resume");
721 drop(current_state); if let Some(timeout) = pause_timeout {
725 match tokio::time::timeout(timeout, resume_signal.notified()).await {
726 Ok(_) => {
727 tracing::debug!("stream resumed");
728 }
729 Err(_) => {
730 tracing::debug!("pause timeout expired, auto-resuming");
731 crate::metrics::counters::stream_pause_timeout_expired(&entity_for_metrics);
732 }
733 }
734 } else {
735 resume_signal.notified().await;
737 tracing::debug!("stream resumed");
738 }
739
740 let mut state = state_lock.lock().await;
742 *state = crate::stream::StreamState::Running;
743 }
744 }
745 }
746
747 tokio::select! {
748 _ = cancel_rx.recv() => {
750 tracing::debug!("query cancelled");
751 crate::metrics::counters::query_completed("cancelled", &entity_for_metrics);
752 break;
753 }
754
755 msg_result = self.receive_message() => {
757 match msg_result {
758 Ok(msg) => match msg {
759 BackendMessage::DataRow(_) => {
760 match extract_json_bytes(&msg) {
761 Ok(json_bytes) => {
762 chunk.push(json_bytes);
763
764 if strategy.is_full(&chunk) {
765 let chunk_start = std::time::Instant::now();
766 let rows = chunk.into_rows();
767 let chunk_size_rows = rows.len() as u64;
768
769 const BATCH_SIZE: usize = 8;
772 let mut batch = Vec::with_capacity(BATCH_SIZE);
773 let mut send_error = false;
774
775 for row_bytes in rows {
776 match parse_json(row_bytes) {
777 Ok(value) => {
778 total_rows += 1;
779 batch.push(Ok(value));
780
781 if batch.len() == BATCH_SIZE {
783 for item in batch.drain(..) {
784 if result_tx.send(item).await.is_err() {
785 crate::metrics::counters::query_completed("error", &entity_for_metrics);
786 send_error = true;
787 break;
788 }
789 }
790 if send_error {
791 break;
792 }
793 }
794 }
795 Err(e) => {
796 crate::metrics::counters::json_parse_error(&entity_for_metrics);
797 let _ = result_tx.send(Err(e)).await;
798 crate::metrics::counters::query_completed("error", &entity_for_metrics);
799 send_error = true;
800 break;
801 }
802 }
803 }
804
805 if !send_error {
807 for item in batch {
808 if result_tx.send(item).await.is_err() {
809 crate::metrics::counters::query_completed("error", &entity_for_metrics);
810 break;
811 }
812 }
813 }
814
815 let chunk_duration = chunk_start.elapsed().as_millis() as u64;
817
818 let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
820 if chunk_idx % 10 == 0 {
821 crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
822 crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
823 }
824
825 chunk = strategy.new_chunk();
831 }
832 }
833 Err(e) => {
834 crate::metrics::counters::json_parse_error(&entity_for_metrics);
835 let _ = result_tx.send(Err(e)).await;
836 crate::metrics::counters::query_completed("error", &entity_for_metrics);
837 break;
838 }
839 }
840 }
841 BackendMessage::CommandComplete(_) => {
842 if !chunk.is_empty() {
844 let chunk_start = std::time::Instant::now();
845 let rows = chunk.into_rows();
846 let chunk_size_rows = rows.len() as u64;
847
848 const BATCH_SIZE: usize = 8;
850 let mut batch = Vec::with_capacity(BATCH_SIZE);
851 let mut send_error = false;
852
853 for row_bytes in rows {
854 match parse_json(row_bytes) {
855 Ok(value) => {
856 total_rows += 1;
857 batch.push(Ok(value));
858
859 if batch.len() == BATCH_SIZE {
861 for item in batch.drain(..) {
862 if result_tx.send(item).await.is_err() {
863 crate::metrics::counters::query_completed("error", &entity_for_metrics);
864 send_error = true;
865 break;
866 }
867 }
868 if send_error {
869 break;
870 }
871 }
872 }
873 Err(e) => {
874 crate::metrics::counters::json_parse_error(&entity_for_metrics);
875 let _ = result_tx.send(Err(e)).await;
876 crate::metrics::counters::query_completed("error", &entity_for_metrics);
877 send_error = true;
878 break;
879 }
880 }
881 }
882
883 if !send_error {
885 for item in batch {
886 if result_tx.send(item).await.is_err() {
887 crate::metrics::counters::query_completed("error", &entity_for_metrics);
888 break;
889 }
890 }
891 }
892
893 let chunk_duration = chunk_start.elapsed().as_millis() as u64;
895 let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
896 if chunk_idx % 10 == 0 {
897 crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
898 crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
899 }
900 chunk = strategy.new_chunk();
901 }
902
903 let query_duration = query_start.elapsed().as_millis() as u64;
905 crate::metrics::counters::rows_processed(&entity_for_metrics, total_rows, "ok");
906 crate::metrics::histograms::query_total_duration(&entity_for_metrics, query_duration);
907 crate::metrics::counters::query_completed("success", &entity_for_metrics);
908 }
909 BackendMessage::ReadyForQuery { .. } => {
910 break;
911 }
912 BackendMessage::ErrorResponse(err) => {
913 crate::metrics::counters::query_error(&entity_for_metrics, "server_error");
914 crate::metrics::counters::query_completed("error", &entity_for_metrics);
915 let _ = result_tx.send(Err(Error::Sql(err.to_string()))).await;
916 break;
917 }
918 _ => {
919 crate::metrics::counters::query_error(&entity_for_metrics, "protocol_error");
920 crate::metrics::counters::query_completed("error", &entity_for_metrics);
921 let _ = result_tx.send(Err(Error::Protocol(
922 format!("unexpected message: {:?}", msg)
923 ))).await;
924 break;
925 }
926 },
927 Err(e) => {
928 crate::metrics::counters::query_error(&entity_for_metrics, "connection_error");
929 crate::metrics::counters::query_completed("error", &entity_for_metrics);
930 let _ = result_tx.send(Err(e)).await;
931 break;
932 }
933 }
934 }
935 }
936 }
937 });
938
939 Ok(stream)
940 }
941 .instrument(tracing::debug_span!(
942 "streaming_query",
943 query = %query,
944 chunk_size = %chunk_size
945 ))
946 .await
947 }
948}
949
/// Derives a metrics label ("entity") from the relation named after `FROM`.
///
/// The query is lowercased and split into tokens on whitespace, `,`, `;`, `(`
/// and `)`. The token after the first standalone `from` keyword is taken as the
/// relation name. Any schema qualifier (`public.`) and double quotes are
/// stripped, and the text after the last `_` is returned. Examples:
/// `v_user` → `user`, `tv_user` → `user`, `public.v_order_item` → `item`.
///
/// Returns `None` when there is no `FROM`, no relation after it, or the
/// relation name has no non-empty suffix after an underscore.
fn extract_entity_from_query(query: &str) -> Option<String> {
    let lowered = query.to_lowercase();
    let mut tokens = lowered
        .split(|c: char| c.is_whitespace() || matches!(c, ',' | ';' | '(' | ')'))
        .filter(|token| !token.is_empty());

    // Whole-token match, so identifiers such as `from_date` are not mistaken for FROM.
    tokens.find(|token| *token == "from")?;
    let relation = tokens.next()?;

    // Drop any schema qualifier and identifier quoting: `"public"."v_user"` -> `v_user`.
    let relation = relation
        .rsplit('.')
        .next()
        .unwrap_or(relation)
        .trim_matches('"');

    let (_, entity) = relation.rsplit_once('_')?;
    if entity.is_empty() {
        None
    } else {
        Some(entity.to_string())
    }
}
972
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_config() {
        let config = ConnectionConfig::new("testdb", "testuser")
            .password("testpass")
            .param("application_name", "fraiseql-wire");

        assert_eq!(config.database, "testdb");
        assert_eq!(config.user, "testuser");
        assert_eq!(
            config.password.as_ref().map(|p| p.as_str()),
            Some("testpass")
        );
        assert_eq!(
            config.params.get("application_name"),
            Some(&"fraiseql-wire".to_string())
        );
    }

    #[test]
    fn test_connection_config_builder_basic() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("mypass")
            .build();

        assert_eq!(config.database, "mydb");
        assert_eq!(config.user, "myuser");
        assert_eq!(config.password.as_ref().map(|p| p.as_str()), Some("mypass"));
        assert_eq!(config.connect_timeout, None);
        assert_eq!(config.statement_timeout, None);
        assert_eq!(config.keepalive_idle, None);
        assert_eq!(config.application_name, None);
    }

    #[test]
    fn test_connection_config_builder_with_timeouts() {
        let connect_timeout = Duration::from_secs(10);
        let statement_timeout = Duration::from_secs(30);
        let keepalive_idle = Duration::from_secs(300);

        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("mypass")
            .connect_timeout(connect_timeout)
            .statement_timeout(statement_timeout)
            .keepalive_idle(keepalive_idle)
            .build();

        assert_eq!(config.connect_timeout, Some(connect_timeout));
        assert_eq!(config.statement_timeout, Some(statement_timeout));
        assert_eq!(config.keepalive_idle, Some(keepalive_idle));
    }

    #[test]
    fn test_connection_config_builder_with_application_name() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .application_name("my_app")
            .extra_float_digits(2)
            .build();

        assert_eq!(config.application_name, Some("my_app".to_string()));
        assert_eq!(config.extra_float_digits, Some(2));
    }

    #[test]
    fn test_connection_config_builder_fluent() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("secret")
            .param("key1", "value1")
            .connect_timeout(Duration::from_secs(5))
            .statement_timeout(Duration::from_secs(60))
            .application_name("test_app")
            .build();

        assert_eq!(config.database, "mydb");
        assert_eq!(config.user, "myuser");
        assert_eq!(config.password.as_ref().map(|p| p.as_str()), Some("secret"));
        assert_eq!(config.params.get("key1"), Some(&"value1".to_string()));
        assert_eq!(config.connect_timeout, Some(Duration::from_secs(5)));
        assert_eq!(config.statement_timeout, Some(Duration::from_secs(60)));
        assert_eq!(config.application_name, Some("test_app".to_string()));
    }

    #[test]
    fn test_connection_config_defaults() {
        let config = ConnectionConfig::new("db", "user");

        assert!(config.connect_timeout.is_none());
        assert!(config.statement_timeout.is_none());
        assert!(config.keepalive_idle.is_none());
        assert!(config.application_name.is_none());
        assert!(config.extra_float_digits.is_none());
    }

    /// The metrics label is the suffix after the last `_` of the relation in FROM.
    #[test]
    fn test_extract_entity_from_query_common_forms() {
        assert_eq!(
            extract_entity_from_query("SELECT data FROM v_user"),
            Some("user".to_string())
        );
        assert_eq!(
            extract_entity_from_query("SELECT data FROM tv_user"),
            Some("user".to_string())
        );
        assert_eq!(
            extract_entity_from_query("select data from v_order;"),
            Some("order".to_string())
        );
        assert_eq!(
            extract_entity_from_query("SELECT data FROM public.v_post"),
            Some("post".to_string())
        );
        assert_eq!(extract_entity_from_query("SELECT 1"), None);
    }

    // NOTE(review): this only requires `Pin<Box<dyn Future<Output = ()> + Send>>`
    // to be `Send`, which holds for any such trait object, so it does not verify
    // that the futures returned by `Connection` methods are `Send`.
    #[allow(dead_code)]
    const _SEND_SAFETY_CHECK: fn() = || {
        fn require_send<T: Send>() {}

        #[allow(unreachable_code)]
        let _ = || {
            require_send::<
                std::pin::Pin<std::boxed::Box<dyn std::future::Future<Output = ()> + Send>>,
            >();
        };
    };
}