1use super::state::ConnectionState;
4use super::transport::Transport;
5use crate::auth::ScramClient;
6use crate::protocol::{
7 decode_message, encode_message, AuthenticationMessage, BackendMessage, FrontendMessage,
8};
9use crate::{Error, Result};
10use bytes::{Buf, BytesMut};
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::time::Duration;
14use tracing::Instrument;
15
16static CHUNK_COUNT: AtomicU64 = AtomicU64::new(0);
19
/// Parameters for opening a PostgreSQL session.
///
/// Create one with [`ConnectionConfig::new`] plus the chaining helpers, or with
/// [`ConnectionConfig::builder`] for the full set of options.
///
/// `Debug` is implemented by hand so that the password is rendered as
/// `"<redacted>"` instead of leaking into logs or panic messages.
#[derive(Clone)]
pub struct ConnectionConfig {
    /// Database name, sent as the `database` startup parameter.
    pub database: String,
    /// Role to connect as, sent as the `user` startup parameter.
    pub user: String,
    /// Password used when the server requests cleartext or SCRAM-SHA-256 auth.
    pub password: Option<String>,
    /// Additional startup parameters, sent after all of the typed options.
    /// A key that duplicates a typed option is sent a second time.
    pub params: HashMap<String, String>,
    /// Connect timeout. NOTE(review): not read by the startup handshake in
    /// this module; presumably applied where the transport is opened — confirm.
    pub connect_timeout: Option<Duration>,
    /// Sent as the `statement_timeout` startup parameter, in milliseconds.
    pub statement_timeout: Option<Duration>,
    /// TCP keepalive idle time. NOTE(review): not read by the startup
    /// handshake in this module; presumably applied to the transport — confirm.
    pub keepalive_idle: Option<Duration>,
    /// Sent as the `application_name` startup parameter.
    pub application_name: Option<String>,
    /// Sent as the `extra_float_digits` startup parameter.
    pub extra_float_digits: Option<i32>,
}

impl std::fmt::Debug for ConnectionConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConnectionConfig")
            .field("database", &self.database)
            .field("user", &self.user)
            // Show only whether a password is set, never the secret itself.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("params", &self.params)
            .field("connect_timeout", &self.connect_timeout)
            .field("statement_timeout", &self.statement_timeout)
            .field("keepalive_idle", &self.keepalive_idle)
            .field("application_name", &self.application_name)
            .field("extra_float_digits", &self.extra_float_digits)
            .finish()
    }
}

impl ConnectionConfig {
    /// Creates a config with only `database` and `user` set.
    ///
    /// Every optional setting starts unset and `params` starts empty.
    pub fn new(database: impl Into<String>, user: impl Into<String>) -> Self {
        // Delegate so the default values live in exactly one place.
        Self::builder(database, user).build()
    }

    /// Starts a [`ConnectionConfigBuilder`] for `database` and `user` with all
    /// optional settings unset.
    pub fn builder(
        database: impl Into<String>,
        user: impl Into<String>,
    ) -> ConnectionConfigBuilder {
        ConnectionConfigBuilder {
            database: database.into(),
            user: user.into(),
            password: None,
            params: HashMap::new(),
            connect_timeout: None,
            statement_timeout: None,
            keepalive_idle: None,
            application_name: None,
            extra_float_digits: None,
        }
    }

    /// Sets the password, replacing any previous one.
    pub fn password(mut self, password: impl Into<String>) -> Self {
        self.password = Some(password.into());
        self
    }

    /// Adds (or overwrites) one extra startup parameter.
    pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.params.insert(key.into(), value.into());
        self
    }
}

/// Fluent builder for [`ConnectionConfig`]; obtain one via
/// [`ConnectionConfig::builder`] and finish with [`ConnectionConfigBuilder::build`].
///
/// Like the config itself, its `Debug` output redacts the password.
#[derive(Clone)]
pub struct ConnectionConfigBuilder {
    database: String,
    user: String,
    password: Option<String>,
    params: HashMap<String, String>,
    connect_timeout: Option<Duration>,
    statement_timeout: Option<Duration>,
    keepalive_idle: Option<Duration>,
    application_name: Option<String>,
    extra_float_digits: Option<i32>,
}

impl std::fmt::Debug for ConnectionConfigBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConnectionConfigBuilder")
            .field("database", &self.database)
            .field("user", &self.user)
            // Show only whether a password is set, never the secret itself.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("params", &self.params)
            .field("connect_timeout", &self.connect_timeout)
            .field("statement_timeout", &self.statement_timeout)
            .field("keepalive_idle", &self.keepalive_idle)
            .field("application_name", &self.application_name)
            .field("extra_float_digits", &self.extra_float_digits)
            .finish()
    }
}

impl ConnectionConfigBuilder {
    /// Sets the password, replacing any previous one.
    pub fn password(mut self, password: impl Into<String>) -> Self {
        self.password = Some(password.into());
        self
    }

    /// Adds (or overwrites) one extra startup parameter.
    pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.params.insert(key.into(), value.into());
        self
    }

    /// Sets the connect timeout.
    pub fn connect_timeout(mut self, duration: Duration) -> Self {
        self.connect_timeout = Some(duration);
        self
    }

    /// Sets the server-side statement timeout (sent in milliseconds).
    pub fn statement_timeout(mut self, duration: Duration) -> Self {
        self.statement_timeout = Some(duration);
        self
    }

    /// Sets the TCP keepalive idle time.
    pub fn keepalive_idle(mut self, duration: Duration) -> Self {
        self.keepalive_idle = Some(duration);
        self
    }

    /// Sets the `application_name` reported to the server.
    pub fn application_name(mut self, name: impl Into<String>) -> Self {
        self.application_name = Some(name.into());
        self
    }

    /// Sets the `extra_float_digits` session setting.
    pub fn extra_float_digits(mut self, digits: i32) -> Self {
        self.extra_float_digits = Some(digits);
        self
    }

    /// Consumes the builder and returns the finished [`ConnectionConfig`].
    pub fn build(self) -> ConnectionConfig {
        ConnectionConfig {
            database: self.database,
            user: self.user,
            password: self.password,
            params: self.params,
            connect_timeout: self.connect_timeout,
            statement_timeout: self.statement_timeout,
            keepalive_idle: self.keepalive_idle,
            application_name: self.application_name,
            extra_float_digits: self.extra_float_digits,
        }
    }
}
235
236pub struct Connection {
238 transport: Transport,
239 state: ConnectionState,
240 read_buf: BytesMut,
241 process_id: Option<i32>,
242 secret_key: Option<i32>,
243}
244
245impl Connection {
246 pub fn new(transport: Transport) -> Self {
248 Self {
249 transport,
250 state: ConnectionState::Initial,
251 read_buf: BytesMut::with_capacity(8192),
252 process_id: None,
253 secret_key: None,
254 }
255 }
256
257 pub fn state(&self) -> ConnectionState {
259 self.state
260 }
261
262 pub async fn startup(&mut self, config: &ConnectionConfig) -> Result<()> {
264 async {
265 self.state.transition(ConnectionState::AwaitingAuth)?;
266
267 let mut params = vec![
269 ("user".to_string(), config.user.clone()),
270 ("database".to_string(), config.database.clone()),
271 ];
272
273 if let Some(app_name) = &config.application_name {
275 params.push(("application_name".to_string(), app_name.clone()));
276 }
277
278 if let Some(timeout) = config.statement_timeout {
280 params.push((
281 "statement_timeout".to_string(),
282 timeout.as_millis().to_string(),
283 ));
284 }
285
286 if let Some(digits) = config.extra_float_digits {
288 params.push(("extra_float_digits".to_string(), digits.to_string()));
289 }
290
291 for (k, v) in &config.params {
293 params.push((k.clone(), v.clone()));
294 }
295
296 let startup = FrontendMessage::Startup {
298 version: crate::protocol::constants::PROTOCOL_VERSION,
299 params,
300 };
301 self.send_message(&startup).await?;
302
303 self.state.transition(ConnectionState::Authenticating)?;
305 self.authenticate(config).await?;
306
307 self.state.transition(ConnectionState::Idle)?;
308 tracing::info!("startup complete");
309 Ok(())
310 }
311 .instrument(tracing::info_span!(
312 "startup",
313 user = %config.user,
314 database = %config.database
315 ))
316 .await
317 }
318
319 async fn authenticate(&mut self, config: &ConnectionConfig) -> Result<()> {
321 let auth_start = std::time::Instant::now();
322 let mut auth_mechanism = "unknown";
323
324 loop {
325 let msg = self.receive_message().await?;
326
327 match msg {
328 BackendMessage::Authentication(auth) => match auth {
329 AuthenticationMessage::Ok => {
330 tracing::debug!("authentication successful");
331 crate::metrics::counters::auth_successful(auth_mechanism);
332 crate::metrics::histograms::auth_duration(
333 auth_mechanism,
334 auth_start.elapsed().as_millis() as u64,
335 );
336 }
338 AuthenticationMessage::CleartextPassword => {
339 auth_mechanism = crate::metrics::labels::MECHANISM_CLEARTEXT;
340 crate::metrics::counters::auth_attempted(auth_mechanism);
341
342 let password = config
343 .password
344 .as_ref()
345 .ok_or_else(|| Error::Authentication("password required".into()))?;
346 let pwd_msg = FrontendMessage::Password(password.clone());
347 self.send_message(&pwd_msg).await?;
348 }
349 AuthenticationMessage::Md5Password { .. } => {
350 return Err(Error::Authentication(
351 "MD5 authentication not supported. Use SCRAM-SHA-256 or cleartext password".into(),
352 ));
353 }
354 AuthenticationMessage::Sasl { mechanisms } => {
355 auth_mechanism = crate::metrics::labels::MECHANISM_SCRAM;
356 crate::metrics::counters::auth_attempted(auth_mechanism);
357 self.handle_sasl(&mechanisms, config).await?;
358 }
359 AuthenticationMessage::SaslContinue { .. } => {
360 return Err(Error::Protocol(
361 "unexpected SaslContinue outside of SASL flow".into(),
362 ));
363 }
364 AuthenticationMessage::SaslFinal { .. } => {
365 return Err(Error::Protocol(
366 "unexpected SaslFinal outside of SASL flow".into(),
367 ));
368 }
369 },
370 BackendMessage::BackendKeyData {
371 process_id,
372 secret_key,
373 } => {
374 self.process_id = Some(process_id);
375 self.secret_key = Some(secret_key);
376 }
377 BackendMessage::ParameterStatus { name, value } => {
378 tracing::debug!("parameter status: {} = {}", name, value);
379 }
380 BackendMessage::ReadyForQuery { status: _ } => {
381 break;
382 }
383 BackendMessage::ErrorResponse(err) => {
384 crate::metrics::counters::auth_failed(auth_mechanism, "server_error");
385 return Err(Error::Authentication(err.to_string()));
386 }
387 _ => {
388 return Err(Error::Protocol(format!(
389 "unexpected message during auth: {:?}",
390 msg
391 )));
392 }
393 }
394 }
395
396 Ok(())
397 }
398
399 async fn handle_sasl(
401 &mut self,
402 mechanisms: &[String],
403 config: &ConnectionConfig,
404 ) -> Result<()> {
405 if !mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
407 return Err(Error::Authentication(format!(
408 "server does not support SCRAM-SHA-256. Available: {}",
409 mechanisms.join(", ")
410 )));
411 }
412
413 let password = config.password.as_ref().ok_or_else(|| {
415 Error::Authentication("password required for SCRAM authentication".into())
416 })?;
417
418 let mut scram = ScramClient::new(config.user.clone(), password.clone());
420 tracing::debug!("initiating SCRAM-SHA-256 authentication");
421
422 let client_first = scram.client_first();
424 let msg = FrontendMessage::SaslInitialResponse {
425 mechanism: "SCRAM-SHA-256".to_string(),
426 data: client_first.into_bytes(),
427 };
428 self.send_message(&msg).await?;
429
430 let server_first_msg = self.receive_message().await?;
432 let server_first_data = match server_first_msg {
433 BackendMessage::Authentication(AuthenticationMessage::SaslContinue { data }) => data,
434 BackendMessage::ErrorResponse(err) => {
435 return Err(Error::Authentication(format!("SASL server error: {}", err)));
436 }
437 _ => {
438 return Err(Error::Protocol(
439 "expected SaslContinue message during SASL authentication".into(),
440 ));
441 }
442 };
443
444 let server_first = String::from_utf8(server_first_data).map_err(|e| {
445 Error::Authentication(format!("invalid UTF-8 in server first message: {}", e))
446 })?;
447
448 tracing::debug!("received SCRAM server first message");
449
450 let (client_final, scram_state) = scram
452 .client_final(&server_first)
453 .map_err(|e| Error::Authentication(format!("SCRAM error: {}", e)))?;
454
455 let msg = FrontendMessage::SaslResponse {
457 data: client_final.into_bytes(),
458 };
459 self.send_message(&msg).await?;
460
461 let server_final_msg = self.receive_message().await?;
463 let server_final_data = match server_final_msg {
464 BackendMessage::Authentication(AuthenticationMessage::SaslFinal { data }) => data,
465 BackendMessage::ErrorResponse(err) => {
466 return Err(Error::Authentication(format!("SASL server error: {}", err)));
467 }
468 _ => {
469 return Err(Error::Protocol(
470 "expected SaslFinal message during SASL authentication".into(),
471 ));
472 }
473 };
474
475 let server_final = String::from_utf8(server_final_data).map_err(|e| {
476 Error::Authentication(format!("invalid UTF-8 in server final message: {}", e))
477 })?;
478
479 scram
481 .verify_server_final(&server_final, &scram_state)
482 .map_err(|e| Error::Authentication(format!("SCRAM verification failed: {}", e)))?;
483
484 tracing::debug!("SCRAM-SHA-256 authentication successful");
485 Ok(())
486 }
487
488 pub async fn simple_query(&mut self, query: &str) -> Result<Vec<BackendMessage>> {
490 if self.state != ConnectionState::Idle {
491 return Err(Error::ConnectionBusy(format!(
492 "connection in state: {}",
493 self.state
494 )));
495 }
496
497 self.state.transition(ConnectionState::QueryInProgress)?;
498
499 let query_msg = FrontendMessage::Query(query.to_string());
500 self.send_message(&query_msg).await?;
501
502 self.state.transition(ConnectionState::ReadingResults)?;
503
504 let mut messages = Vec::new();
505
506 loop {
507 let msg = self.receive_message().await?;
508 let is_ready = matches!(msg, BackendMessage::ReadyForQuery { .. });
509 messages.push(msg);
510
511 if is_ready {
512 break;
513 }
514 }
515
516 self.state.transition(ConnectionState::Idle)?;
517 Ok(messages)
518 }
519
520 async fn send_message(&mut self, msg: &FrontendMessage) -> Result<()> {
522 let buf = encode_message(msg)?;
523 self.transport.write_all(&buf).await?;
524 self.transport.flush().await?;
525 Ok(())
526 }
527
528 async fn receive_message(&mut self) -> Result<BackendMessage> {
530 loop {
531 if let Ok((msg, consumed)) = decode_message(&mut self.read_buf) {
533 self.read_buf.advance(consumed);
534 return Ok(msg);
535 }
536
537 let n = self.transport.read_buf(&mut self.read_buf).await?;
539 if n == 0 {
540 return Err(Error::ConnectionClosed);
541 }
542 }
543 }
544
545 pub async fn close(mut self) -> Result<()> {
547 self.state.transition(ConnectionState::Closed)?;
548 let _ = self.send_message(&FrontendMessage::Terminate).await;
549 self.transport.shutdown().await?;
550 Ok(())
551 }
552
553 #[allow(clippy::too_many_arguments)]
558 pub async fn streaming_query(
559 mut self,
560 query: &str,
561 chunk_size: usize,
562 max_memory: Option<usize>,
563 soft_limit_warn_threshold: Option<f32>,
564 soft_limit_fail_threshold: Option<f32>,
565 enable_adaptive_chunking: bool,
566 adaptive_min_chunk_size: Option<usize>,
567 adaptive_max_chunk_size: Option<usize>,
568 ) -> Result<crate::stream::JsonStream> {
569 async {
570 let startup_start = std::time::Instant::now();
571
572 use crate::json::validate_row_description;
573 use crate::stream::{extract_json_bytes, parse_json, AdaptiveChunking, ChunkingStrategy, JsonStream};
574 use serde_json::Value;
575 use tokio::sync::mpsc;
576
577 if self.state != ConnectionState::Idle {
578 return Err(Error::ConnectionBusy(format!(
579 "connection in state: {}",
580 self.state
581 )));
582 }
583
584 self.state.transition(ConnectionState::QueryInProgress)?;
585
586 let query_msg = FrontendMessage::Query(query.to_string());
587 self.send_message(&query_msg).await?;
588
589 self.state.transition(ConnectionState::ReadingResults)?;
590
591 let row_desc;
594 loop {
595 let msg = self.receive_message().await?;
596
597 match msg {
598 BackendMessage::ErrorResponse(err) => {
599 tracing::debug!("PostgreSQL error response: {}", err);
601 loop {
602 let msg = self.receive_message().await?;
603 if matches!(msg, BackendMessage::ReadyForQuery { .. }) {
604 break;
605 }
606 }
607 return Err(Error::Sql(err.to_string()));
608 }
609 BackendMessage::BackendKeyData { process_id, secret_key: _ } => {
610 tracing::debug!("PostgreSQL backend key data received: pid={}", process_id);
612 continue;
614 }
615 BackendMessage::ParameterStatus { .. } => {
616 tracing::debug!("PostgreSQL parameter status change received");
618 continue;
619 }
620 BackendMessage::NoticeResponse(notice) => {
621 tracing::debug!("PostgreSQL notice: {}", notice);
623 continue;
624 }
625 BackendMessage::RowDescription(_) => {
626 row_desc = msg;
627 break;
628 }
629 BackendMessage::ReadyForQuery { .. } => {
630 return Err(Error::Protocol(
633 "no result set received from query - \
634 check that the entity name is correct and the table/view exists"
635 .into(),
636 ));
637 }
638 _ => {
639 return Err(Error::Protocol(format!(
640 "unexpected message type in query response: {:?}",
641 msg
642 )));
643 }
644 }
645 }
646
647 validate_row_description(&row_desc)?;
648
649 let startup_duration = startup_start.elapsed().as_millis() as u64;
651 let entity = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
652 crate::metrics::histograms::query_startup_duration(&entity, startup_duration);
653
654 let (result_tx, result_rx) = mpsc::channel::<Result<Value>>(chunk_size);
656 let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
657
658 let entity_for_metrics = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
660 let entity_for_stream = entity_for_metrics.clone(); let stream = JsonStream::new(
663 result_rx,
664 cancel_tx,
665 entity_for_stream,
666 max_memory,
667 soft_limit_warn_threshold,
668 soft_limit_fail_threshold,
669 );
670
671 let state_lock = stream.clone_state();
673 let pause_signal = stream.clone_pause_signal();
674 let resume_signal = stream.clone_resume_signal();
675
676 let state_atomic = stream.clone_state_atomic();
678
679 let pause_timeout = stream.pause_timeout();
681
682 let query_start = std::time::Instant::now();
684
685 tokio::spawn(async move {
686 let strategy = ChunkingStrategy::new(chunk_size);
687 let mut chunk = strategy.new_chunk();
688 let mut total_rows = 0u64;
689
690 let _adaptive = if enable_adaptive_chunking {
692 let mut adp = AdaptiveChunking::new();
693
694 if let Some(min) = adaptive_min_chunk_size {
696 if let Some(max) = adaptive_max_chunk_size {
697 adp = adp.with_bounds(min, max);
698 }
699 }
700
701 Some(adp)
702 } else {
703 None
704 };
705 let _current_chunk_size = chunk_size;
706
707 loop {
708 if state_lock.is_some() && state_atomic.load(std::sync::atomic::Ordering::Acquire) == 1 {
711 if let (Some(ref state_lock), Some(ref _pause_signal), Some(ref resume_signal)) =
713 (&state_lock, &pause_signal, &resume_signal)
714 {
715 let current_state = state_lock.lock().await;
716 if *current_state == crate::stream::StreamState::Paused {
717 tracing::debug!("stream paused, waiting for resume");
718 drop(current_state); if let Some(timeout) = pause_timeout {
722 match tokio::time::timeout(timeout, resume_signal.notified()).await {
723 Ok(_) => {
724 tracing::debug!("stream resumed");
725 }
726 Err(_) => {
727 tracing::debug!("pause timeout expired, auto-resuming");
728 crate::metrics::counters::stream_pause_timeout_expired(&entity_for_metrics);
729 }
730 }
731 } else {
732 resume_signal.notified().await;
734 tracing::debug!("stream resumed");
735 }
736
737 let mut state = state_lock.lock().await;
739 *state = crate::stream::StreamState::Running;
740 }
741 }
742 }
743
744 tokio::select! {
745 _ = cancel_rx.recv() => {
747 tracing::debug!("query cancelled");
748 crate::metrics::counters::query_completed("cancelled", &entity_for_metrics);
749 break;
750 }
751
752 msg_result = self.receive_message() => {
754 match msg_result {
755 Ok(msg) => match msg {
756 BackendMessage::DataRow(_) => {
757 match extract_json_bytes(&msg) {
758 Ok(json_bytes) => {
759 chunk.push(json_bytes);
760
761 if strategy.is_full(&chunk) {
762 let chunk_start = std::time::Instant::now();
763 let rows = chunk.into_rows();
764 let chunk_size_rows = rows.len() as u64;
765
766 const BATCH_SIZE: usize = 8;
769 let mut batch = Vec::with_capacity(BATCH_SIZE);
770 let mut send_error = false;
771
772 for row_bytes in rows {
773 match parse_json(row_bytes) {
774 Ok(value) => {
775 total_rows += 1;
776 batch.push(Ok(value));
777
778 if batch.len() == BATCH_SIZE {
780 for item in batch.drain(..) {
781 if result_tx.send(item).await.is_err() {
782 crate::metrics::counters::query_completed("error", &entity_for_metrics);
783 send_error = true;
784 break;
785 }
786 }
787 if send_error {
788 break;
789 }
790 }
791 }
792 Err(e) => {
793 crate::metrics::counters::json_parse_error(&entity_for_metrics);
794 let _ = result_tx.send(Err(e)).await;
795 crate::metrics::counters::query_completed("error", &entity_for_metrics);
796 send_error = true;
797 break;
798 }
799 }
800 }
801
802 if !send_error {
804 for item in batch {
805 if result_tx.send(item).await.is_err() {
806 crate::metrics::counters::query_completed("error", &entity_for_metrics);
807 break;
808 }
809 }
810 }
811
812 let chunk_duration = chunk_start.elapsed().as_millis() as u64;
814
815 let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
817 if chunk_idx % 10 == 0 {
818 crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
819 crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
820 }
821
822 chunk = strategy.new_chunk();
828 }
829 }
830 Err(e) => {
831 crate::metrics::counters::json_parse_error(&entity_for_metrics);
832 let _ = result_tx.send(Err(e)).await;
833 crate::metrics::counters::query_completed("error", &entity_for_metrics);
834 break;
835 }
836 }
837 }
838 BackendMessage::CommandComplete(_) => {
839 if !chunk.is_empty() {
841 let chunk_start = std::time::Instant::now();
842 let rows = chunk.into_rows();
843 let chunk_size_rows = rows.len() as u64;
844
845 const BATCH_SIZE: usize = 8;
847 let mut batch = Vec::with_capacity(BATCH_SIZE);
848 let mut send_error = false;
849
850 for row_bytes in rows {
851 match parse_json(row_bytes) {
852 Ok(value) => {
853 total_rows += 1;
854 batch.push(Ok(value));
855
856 if batch.len() == BATCH_SIZE {
858 for item in batch.drain(..) {
859 if result_tx.send(item).await.is_err() {
860 crate::metrics::counters::query_completed("error", &entity_for_metrics);
861 send_error = true;
862 break;
863 }
864 }
865 if send_error {
866 break;
867 }
868 }
869 }
870 Err(e) => {
871 crate::metrics::counters::json_parse_error(&entity_for_metrics);
872 let _ = result_tx.send(Err(e)).await;
873 crate::metrics::counters::query_completed("error", &entity_for_metrics);
874 send_error = true;
875 break;
876 }
877 }
878 }
879
880 if !send_error {
882 for item in batch {
883 if result_tx.send(item).await.is_err() {
884 crate::metrics::counters::query_completed("error", &entity_for_metrics);
885 break;
886 }
887 }
888 }
889
890 let chunk_duration = chunk_start.elapsed().as_millis() as u64;
892 let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
893 if chunk_idx % 10 == 0 {
894 crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
895 crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
896 }
897 chunk = strategy.new_chunk();
898 }
899
900 let query_duration = query_start.elapsed().as_millis() as u64;
902 crate::metrics::counters::rows_processed(&entity_for_metrics, total_rows, "ok");
903 crate::metrics::histograms::query_total_duration(&entity_for_metrics, query_duration);
904 crate::metrics::counters::query_completed("success", &entity_for_metrics);
905 }
906 BackendMessage::ReadyForQuery { .. } => {
907 break;
908 }
909 BackendMessage::ErrorResponse(err) => {
910 crate::metrics::counters::query_error(&entity_for_metrics, "server_error");
911 crate::metrics::counters::query_completed("error", &entity_for_metrics);
912 let _ = result_tx.send(Err(Error::Sql(err.to_string()))).await;
913 break;
914 }
915 _ => {
916 crate::metrics::counters::query_error(&entity_for_metrics, "protocol_error");
917 crate::metrics::counters::query_completed("error", &entity_for_metrics);
918 let _ = result_tx.send(Err(Error::Protocol(
919 format!("unexpected message: {:?}", msg)
920 ))).await;
921 break;
922 }
923 },
924 Err(e) => {
925 crate::metrics::counters::query_error(&entity_for_metrics, "connection_error");
926 crate::metrics::counters::query_completed("error", &entity_for_metrics);
927 let _ = result_tx.send(Err(e)).await;
928 break;
929 }
930 }
931 }
932 }
933 }
934 });
935
936 Ok(stream)
937 }
938 .instrument(tracing::debug_span!(
939 "streaming_query",
940 query = %query,
941 chunk_size = %chunk_size
942 ))
943 .await
944 }
945}
946
/// Derives a metrics label ("entity") from the relation named after the first
/// standalone `FROM` keyword in `query`.
///
/// How the label is derived:
/// - The relation is the first token after `FROM`, ending at whitespace, `;`,
///   `,`, `(` or `)`.
/// - Any schema qualifier and double quotes are stripped.
/// - The entity is the text after the relation's last `_`, so `FROM v_user`
///   and `FROM public."tv_user"` both yield `user`.
///
/// Matching is case-insensitive and the result is lowercase. Returns `None`
/// in these cases:
/// - there is no standalone `FROM`;
/// - the relation has no `_`;
/// - nothing follows the last `_`.
fn extract_entity_from_query(query: &str) -> Option<String> {
    let lower = query.to_lowercase();
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';

    // Find a FROM that is a keyword, not part of an identifier such as `from_date`.
    let mut search_start = 0;
    let after_from = loop {
        let pos = search_start + lower[search_start..].find("from")?;
        let end = pos + "from".len();
        let standalone_before = !lower[..pos].ends_with(is_ident_char);
        let rest = &lower[end..];
        let standalone_after = rest.starts_with(|c: char| c.is_whitespace() || c == '"');
        if standalone_before && standalone_after {
            break rest;
        }
        search_start = end;
    };

    // Only the relation token matters; never scan into the rest of the query.
    let token = after_from
        .trim_start()
        .split(|c: char| c.is_whitespace() || matches!(c, ';' | ',' | '(' | ')'))
        .next()?;
    let relation = token.rsplit('.').next()?.trim_matches('"');
    let (_, entity) = relation.rsplit_once('_')?;

    if entity.is_empty() {
        None
    } else {
        Some(entity.to_string())
    }
}
969
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_config() {
        let config = ConnectionConfig::new("testdb", "testuser")
            .password("testpass")
            .param("application_name", "fraiseql-wire");

        assert_eq!(config.database, "testdb");
        assert_eq!(config.user, "testuser");
        assert_eq!(config.password, Some("testpass".to_string()));
        assert_eq!(
            config.params.get("application_name"),
            Some(&"fraiseql-wire".to_string())
        );
    }

    #[test]
    fn test_connection_config_builder_basic() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("mypass")
            .build();

        assert_eq!(config.database, "mydb");
        assert_eq!(config.user, "myuser");
        assert_eq!(config.password, Some("mypass".to_string()));
        assert_eq!(config.connect_timeout, None);
        assert_eq!(config.statement_timeout, None);
        assert_eq!(config.keepalive_idle, None);
        assert_eq!(config.application_name, None);
    }

    #[test]
    fn test_connection_config_builder_with_timeouts() {
        let connect_timeout = Duration::from_secs(10);
        let statement_timeout = Duration::from_secs(30);
        let keepalive_idle = Duration::from_secs(300);

        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("mypass")
            .connect_timeout(connect_timeout)
            .statement_timeout(statement_timeout)
            .keepalive_idle(keepalive_idle)
            .build();

        assert_eq!(config.connect_timeout, Some(connect_timeout));
        assert_eq!(config.statement_timeout, Some(statement_timeout));
        assert_eq!(config.keepalive_idle, Some(keepalive_idle));
    }

    #[test]
    fn test_connection_config_builder_with_application_name() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .application_name("my_app")
            .extra_float_digits(2)
            .build();

        assert_eq!(config.application_name, Some("my_app".to_string()));
        assert_eq!(config.extra_float_digits, Some(2));
    }

    #[test]
    fn test_connection_config_builder_fluent() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("secret")
            .param("key1", "value1")
            .connect_timeout(Duration::from_secs(5))
            .statement_timeout(Duration::from_secs(60))
            .application_name("test_app")
            .build();

        assert_eq!(config.database, "mydb");
        assert_eq!(config.user, "myuser");
        assert_eq!(config.password, Some("secret".to_string()));
        assert_eq!(config.params.get("key1"), Some(&"value1".to_string()));
        assert_eq!(config.connect_timeout, Some(Duration::from_secs(5)));
        assert_eq!(config.statement_timeout, Some(Duration::from_secs(60)));
        assert_eq!(config.application_name, Some("test_app".to_string()));
    }

    #[test]
    fn test_connection_config_defaults() {
        let config = ConnectionConfig::new("db", "user");

        assert!(config.connect_timeout.is_none());
        assert!(config.statement_timeout.is_none());
        assert!(config.keepalive_idle.is_none());
        assert!(config.application_name.is_none());
        assert!(config.extra_float_digits.is_none());
    }

    #[test]
    fn test_extract_entity_from_query_basic() {
        assert_eq!(
            extract_entity_from_query("SELECT data FROM v_user"),
            Some("user".to_string())
        );
        assert_eq!(
            extract_entity_from_query("SELECT data FROM v_user WHERE id = 1"),
            Some("user".to_string())
        );
        assert_eq!(extract_entity_from_query("SELECT 1"), None);
    }

    /// `streaming_query` moves the connection into `tokio::spawn`, which
    /// requires `Send`. This fails to compile if that property regresses.
    /// (The previous check only asserted that a boxed `Send` future is `Send`,
    /// which is always true.)
    #[test]
    fn test_connection_is_send() {
        fn require_send<T: Send>() {}
        require_send::<Connection>();
    }
}