Skip to main content

fraiseql_wire/connection/
conn.rs

1//! Core connection type
2
3use super::state::ConnectionState;
4use super::transport::Transport;
5use crate::auth::ScramClient;
6use crate::protocol::{
7    decode_message, encode_message, AuthenticationMessage, BackendMessage, FrontendMessage,
8};
9use crate::{Error, Result};
10use bytes::{Buf, BytesMut};
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::time::Duration;
14use tracing::Instrument;
15use zeroize::Zeroizing;
16
// Global counter for chunk metrics sampling (1 per 10 chunks).
// Used to reduce per-chunk metric recording overhead. Being a `static`, it is
// shared by every connection in the process.
// NOTE(review): the code that reads this counter is outside this view —
// confirm the 1-in-10 sampling ratio described above still matches it.
static CHUNK_COUNT: AtomicU64 = AtomicU64::new(0);
20
21/// Connection configuration
22///
23/// Stores connection parameters including database, credentials, and optional timeouts.
24/// Use `ConnectionConfig::builder()` for advanced configuration with timeouts and keepalive.
25#[derive(Debug, Clone)]
26pub struct ConnectionConfig {
27    /// Database name
28    pub database: String,
29    /// Username
30    pub user: String,
31    /// Password (optional, zeroed on drop for security)
32    pub password: Option<Zeroizing<String>>,
33    /// Additional connection parameters
34    pub params: HashMap<String, String>,
35    /// TCP connection timeout (default: 10 seconds)
36    pub connect_timeout: Option<Duration>,
37    /// Query statement timeout
38    pub statement_timeout: Option<Duration>,
39    /// TCP keepalive idle interval (default: 5 minutes)
40    pub keepalive_idle: Option<Duration>,
41    /// Application name for Postgres logs (default: "fraiseql-wire")
42    pub application_name: Option<String>,
43    /// Postgres extra_float_digits setting
44    pub extra_float_digits: Option<i32>,
45}
46
47impl ConnectionConfig {
48    /// Create new configuration with defaults
49    ///
50    /// # Arguments
51    ///
52    /// * `database` - Database name
53    /// * `user` - Username
54    ///
55    /// # Defaults
56    ///
57    /// - `connect_timeout`: None
58    /// - `statement_timeout`: None
59    /// - `keepalive_idle`: None
60    /// - `application_name`: None
61    /// - `extra_float_digits`: None
62    ///
63    /// For configured timeouts and keepalive, use `builder()` instead.
64    pub fn new(database: impl Into<String>, user: impl Into<String>) -> Self {
65        Self {
66            database: database.into(),
67            user: user.into(),
68            password: None,
69            params: HashMap::new(),
70            connect_timeout: None,
71            statement_timeout: None,
72            keepalive_idle: None,
73            application_name: None,
74            extra_float_digits: None,
75        }
76    }
77
78    /// Create a builder for advanced configuration
79    ///
80    /// Use this to configure timeouts, keepalive, and application name.
81    ///
82    /// # Examples
83    ///
84    /// ```ignore
85    /// let config = ConnectionConfig::builder("mydb", "user")
86    ///     .connect_timeout(Duration::from_secs(10))
87    ///     .statement_timeout(Duration::from_secs(30))
88    ///     .build();
89    /// ```
90    pub fn builder(
91        database: impl Into<String>,
92        user: impl Into<String>,
93    ) -> ConnectionConfigBuilder {
94        ConnectionConfigBuilder {
95            database: database.into(),
96            user: user.into(),
97            password: None,
98            params: HashMap::new(),
99            connect_timeout: None,
100            statement_timeout: None,
101            keepalive_idle: None,
102            application_name: None,
103            extra_float_digits: None,
104        }
105    }
106
107    /// Set password (automatically zeroed on drop)
108    pub fn password(mut self, password: impl Into<String>) -> Self {
109        self.password = Some(Zeroizing::new(password.into()));
110        self
111    }
112
113    /// Add connection parameter
114    pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
115        self.params.insert(key.into(), value.into());
116        self
117    }
118}
119
120/// Builder for creating `ConnectionConfig` with advanced options
121///
122/// Provides a fluent API for configuring timeouts, keepalive, and application name.
123///
124/// # Examples
125///
126/// ```ignore
127/// let config = ConnectionConfig::builder("mydb", "user")
128///     .password("secret")
129///     .connect_timeout(Duration::from_secs(10))
130///     .statement_timeout(Duration::from_secs(30))
131///     .keepalive_idle(Duration::from_secs(300))
132///     .application_name("my_app")
133///     .build();
134/// ```
135#[derive(Debug, Clone)]
136pub struct ConnectionConfigBuilder {
137    database: String,
138    user: String,
139    password: Option<Zeroizing<String>>,
140    params: HashMap<String, String>,
141    connect_timeout: Option<Duration>,
142    statement_timeout: Option<Duration>,
143    keepalive_idle: Option<Duration>,
144    application_name: Option<String>,
145    extra_float_digits: Option<i32>,
146}
147
148impl ConnectionConfigBuilder {
149    /// Set the password (automatically zeroed on drop)
150    pub fn password(mut self, password: impl Into<String>) -> Self {
151        self.password = Some(Zeroizing::new(password.into()));
152        self
153    }
154
155    /// Add a connection parameter
156    pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
157        self.params.insert(key.into(), value.into());
158        self
159    }
160
161    /// Set TCP connection timeout
162    ///
163    /// Default: None (no timeout)
164    ///
165    /// # Arguments
166    ///
167    /// * `duration` - Timeout duration for establishing TCP connection
168    pub fn connect_timeout(mut self, duration: Duration) -> Self {
169        self.connect_timeout = Some(duration);
170        self
171    }
172
173    /// Set statement (query) timeout
174    ///
175    /// Default: None (unlimited)
176    ///
177    /// # Arguments
178    ///
179    /// * `duration` - Timeout duration for query execution
180    pub fn statement_timeout(mut self, duration: Duration) -> Self {
181        self.statement_timeout = Some(duration);
182        self
183    }
184
185    /// Set TCP keepalive idle interval
186    ///
187    /// Default: None (OS default)
188    ///
189    /// # Arguments
190    ///
191    /// * `duration` - Idle duration before sending keepalive probes
192    pub fn keepalive_idle(mut self, duration: Duration) -> Self {
193        self.keepalive_idle = Some(duration);
194        self
195    }
196
197    /// Set application name for Postgres logs
198    ///
199    /// Default: None (Postgres will not set application_name)
200    ///
201    /// # Arguments
202    ///
203    /// * `name` - Application name to identify in Postgres logs
204    pub fn application_name(mut self, name: impl Into<String>) -> Self {
205        self.application_name = Some(name.into());
206        self
207    }
208
209    /// Set extra_float_digits for float precision
210    ///
211    /// Default: None (use Postgres default)
212    ///
213    /// # Arguments
214    ///
215    /// * `digits` - Number of extra digits (typically 0-2)
216    pub fn extra_float_digits(mut self, digits: i32) -> Self {
217        self.extra_float_digits = Some(digits);
218        self
219    }
220
221    /// Build the configuration
222    pub fn build(self) -> ConnectionConfig {
223        ConnectionConfig {
224            database: self.database,
225            user: self.user,
226            password: self.password,
227            params: self.params,
228            connect_timeout: self.connect_timeout,
229            statement_timeout: self.statement_timeout,
230            keepalive_idle: self.keepalive_idle,
231            application_name: self.application_name,
232            extra_float_digits: self.extra_float_digits,
233        }
234    }
235}
236
237/// Postgres connection
238pub struct Connection {
239    transport: Transport,
240    state: ConnectionState,
241    read_buf: BytesMut,
242    process_id: Option<i32>,
243    secret_key: Option<i32>,
244}
245
246impl Connection {
247    /// Create connection from transport
248    pub fn new(transport: Transport) -> Self {
249        Self {
250            transport,
251            state: ConnectionState::Initial,
252            read_buf: BytesMut::with_capacity(8192),
253            process_id: None,
254            secret_key: None,
255        }
256    }
257
258    /// Get current connection state
259    pub fn state(&self) -> ConnectionState {
260        self.state
261    }
262
263    /// Perform startup and authentication
264    pub async fn startup(&mut self, config: &ConnectionConfig) -> Result<()> {
265        async {
266            self.state.transition(ConnectionState::AwaitingAuth)?;
267
268            // Build startup parameters
269            let mut params = vec![
270                ("user".to_string(), config.user.clone()),
271                ("database".to_string(), config.database.clone()),
272            ];
273
274            // Add configured application name if specified
275            if let Some(app_name) = &config.application_name {
276                params.push(("application_name".to_string(), app_name.clone()));
277            }
278
279            // Add statement timeout if specified (in milliseconds)
280            if let Some(timeout) = config.statement_timeout {
281                params.push((
282                    "statement_timeout".to_string(),
283                    timeout.as_millis().to_string(),
284                ));
285            }
286
287            // Add extra_float_digits if specified
288            if let Some(digits) = config.extra_float_digits {
289                params.push(("extra_float_digits".to_string(), digits.to_string()));
290            }
291
292            // Add user-provided parameters
293            for (k, v) in &config.params {
294                params.push((k.clone(), v.clone()));
295            }
296
297            // Send startup message
298            let startup = FrontendMessage::Startup {
299                version: crate::protocol::constants::PROTOCOL_VERSION,
300                params,
301            };
302            self.send_message(&startup).await?;
303
304            // Authentication loop
305            self.state.transition(ConnectionState::Authenticating)?;
306            self.authenticate(config).await?;
307
308            self.state.transition(ConnectionState::Idle)?;
309            tracing::info!("startup complete");
310            Ok(())
311        }
312        .instrument(tracing::info_span!(
313            "startup",
314            user = %config.user,
315            database = %config.database
316        ))
317        .await
318    }
319
320    /// Handle authentication
321    async fn authenticate(&mut self, config: &ConnectionConfig) -> Result<()> {
322        let auth_start = std::time::Instant::now();
323        let mut auth_mechanism = "unknown";
324
325        loop {
326            let msg = self.receive_message().await?;
327
328            match msg {
329                BackendMessage::Authentication(auth) => match auth {
330                    AuthenticationMessage::Ok => {
331                        tracing::debug!("authentication successful");
332                        crate::metrics::counters::auth_successful(auth_mechanism);
333                        crate::metrics::histograms::auth_duration(
334                            auth_mechanism,
335                            auth_start.elapsed().as_millis() as u64,
336                        );
337                        // Don't break here! Must continue reading until ReadyForQuery
338                    }
339                    AuthenticationMessage::CleartextPassword => {
340                        auth_mechanism = crate::metrics::labels::MECHANISM_CLEARTEXT;
341                        crate::metrics::counters::auth_attempted(auth_mechanism);
342
343                        let password = config
344                            .password
345                            .as_ref()
346                            .ok_or_else(|| Error::Authentication("password required".into()))?;
347                        // SECURITY: Convert from Zeroizing wrapper while preserving password content
348                        let pwd_msg = FrontendMessage::Password(password.as_str().to_string());
349                        self.send_message(&pwd_msg).await?;
350                    }
351                    AuthenticationMessage::Md5Password { .. } => {
352                        return Err(Error::Authentication(
353                            "MD5 authentication not supported. Use SCRAM-SHA-256 or cleartext password".into(),
354                        ));
355                    }
356                    AuthenticationMessage::Sasl { mechanisms } => {
357                        auth_mechanism = crate::metrics::labels::MECHANISM_SCRAM;
358                        crate::metrics::counters::auth_attempted(auth_mechanism);
359                        self.handle_sasl(&mechanisms, config).await?;
360                    }
361                    AuthenticationMessage::SaslContinue { .. } => {
362                        return Err(Error::Protocol(
363                            "unexpected SaslContinue outside of SASL flow".into(),
364                        ));
365                    }
366                    AuthenticationMessage::SaslFinal { .. } => {
367                        return Err(Error::Protocol(
368                            "unexpected SaslFinal outside of SASL flow".into(),
369                        ));
370                    }
371                },
372                BackendMessage::BackendKeyData {
373                    process_id,
374                    secret_key,
375                } => {
376                    self.process_id = Some(process_id);
377                    self.secret_key = Some(secret_key);
378                }
379                BackendMessage::ParameterStatus { name, value } => {
380                    tracing::debug!("parameter status: {} = {}", name, value);
381                }
382                BackendMessage::ReadyForQuery { status: _ } => {
383                    break;
384                }
385                BackendMessage::ErrorResponse(err) => {
386                    crate::metrics::counters::auth_failed(auth_mechanism, "server_error");
387                    return Err(Error::Authentication(err.to_string()));
388                }
389                _ => {
390                    return Err(Error::Protocol(format!(
391                        "unexpected message during auth: {:?}",
392                        msg
393                    )));
394                }
395            }
396        }
397
398        Ok(())
399    }
400
401    /// Handle SASL authentication (SCRAM-SHA-256)
402    async fn handle_sasl(
403        &mut self,
404        mechanisms: &[String],
405        config: &ConnectionConfig,
406    ) -> Result<()> {
407        // Check if server supports SCRAM-SHA-256
408        if !mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
409            return Err(Error::Authentication(format!(
410                "server does not support SCRAM-SHA-256. Available: {}",
411                mechanisms.join(", ")
412            )));
413        }
414
415        // Get password
416        let password = config.password.as_ref().ok_or_else(|| {
417            Error::Authentication("password required for SCRAM authentication".into())
418        })?;
419
420        // Create SCRAM client
421        // SECURITY: Convert from Zeroizing wrapper while preserving password content
422        let mut scram = ScramClient::new(config.user.clone(), password.as_str().to_string());
423        tracing::debug!("initiating SCRAM-SHA-256 authentication");
424
425        // Send SaslInitialResponse with client first message
426        let client_first = scram.client_first();
427        let msg = FrontendMessage::SaslInitialResponse {
428            mechanism: "SCRAM-SHA-256".to_string(),
429            data: client_first.into_bytes(),
430        };
431        self.send_message(&msg).await?;
432
433        // Receive SaslContinue with server first message
434        let server_first_msg = self.receive_message().await?;
435        let server_first_data = match server_first_msg {
436            BackendMessage::Authentication(AuthenticationMessage::SaslContinue { data }) => data,
437            BackendMessage::ErrorResponse(err) => {
438                return Err(Error::Authentication(format!("SASL server error: {}", err)));
439            }
440            _ => {
441                return Err(Error::Protocol(
442                    "expected SaslContinue message during SASL authentication".into(),
443                ));
444            }
445        };
446
447        let server_first = String::from_utf8(server_first_data).map_err(|e| {
448            Error::Authentication(format!("invalid UTF-8 in server first message: {}", e))
449        })?;
450
451        tracing::debug!("received SCRAM server first message");
452
453        // Generate client final message
454        let (client_final, scram_state) = scram
455            .client_final(&server_first)
456            .map_err(|e| Error::Authentication(format!("SCRAM error: {}", e)))?;
457
458        // Send SaslResponse with client final message
459        let msg = FrontendMessage::SaslResponse {
460            data: client_final.into_bytes(),
461        };
462        self.send_message(&msg).await?;
463
464        // Receive SaslFinal with server verification
465        let server_final_msg = self.receive_message().await?;
466        let server_final_data = match server_final_msg {
467            BackendMessage::Authentication(AuthenticationMessage::SaslFinal { data }) => data,
468            BackendMessage::ErrorResponse(err) => {
469                return Err(Error::Authentication(format!("SASL server error: {}", err)));
470            }
471            _ => {
472                return Err(Error::Protocol(
473                    "expected SaslFinal message during SASL authentication".into(),
474                ));
475            }
476        };
477
478        let server_final = String::from_utf8(server_final_data).map_err(|e| {
479            Error::Authentication(format!("invalid UTF-8 in server final message: {}", e))
480        })?;
481
482        // Verify server signature
483        scram
484            .verify_server_final(&server_final, &scram_state)
485            .map_err(|e| Error::Authentication(format!("SCRAM verification failed: {}", e)))?;
486
487        tracing::debug!("SCRAM-SHA-256 authentication successful");
488        Ok(())
489    }
490
491    /// Execute a simple query (returns all backend messages)
492    pub async fn simple_query(&mut self, query: &str) -> Result<Vec<BackendMessage>> {
493        if self.state != ConnectionState::Idle {
494            return Err(Error::ConnectionBusy(format!(
495                "connection in state: {}",
496                self.state
497            )));
498        }
499
500        self.state.transition(ConnectionState::QueryInProgress)?;
501
502        let query_msg = FrontendMessage::Query(query.to_string());
503        self.send_message(&query_msg).await?;
504
505        self.state.transition(ConnectionState::ReadingResults)?;
506
507        let mut messages = Vec::new();
508
509        loop {
510            let msg = self.receive_message().await?;
511            let is_ready = matches!(msg, BackendMessage::ReadyForQuery { .. });
512            messages.push(msg);
513
514            if is_ready {
515                break;
516            }
517        }
518
519        self.state.transition(ConnectionState::Idle)?;
520        Ok(messages)
521    }
522
523    /// Send a frontend message
524    async fn send_message(&mut self, msg: &FrontendMessage) -> Result<()> {
525        let buf = encode_message(msg)?;
526        self.transport.write_all(&buf).await?;
527        self.transport.flush().await?;
528        Ok(())
529    }
530
531    /// Receive a backend message
532    async fn receive_message(&mut self) -> Result<BackendMessage> {
533        loop {
534            // Try to decode a message from buffer (without cloning!)
535            if let Ok((msg, consumed)) = decode_message(&mut self.read_buf) {
536                self.read_buf.advance(consumed);
537                return Ok(msg);
538            }
539
540            // Need more data
541            let n = self.transport.read_buf(&mut self.read_buf).await?;
542            if n == 0 {
543                return Err(Error::ConnectionClosed);
544            }
545        }
546    }
547
548    /// Close the connection
549    pub async fn close(mut self) -> Result<()> {
550        self.state.transition(ConnectionState::Closed)?;
551        let _ = self.send_message(&FrontendMessage::Terminate).await;
552        self.transport.shutdown().await?;
553        Ok(())
554    }
555
556    /// Execute a streaming query
557    ///
558    /// Note: This method consumes the connection. The stream maintains the connection
559    /// internally. Once the stream is exhausted or dropped, the connection is closed.
560    #[allow(clippy::too_many_arguments)]
561    pub async fn streaming_query(
562        mut self,
563        query: &str,
564        chunk_size: usize,
565        max_memory: Option<usize>,
566        soft_limit_warn_threshold: Option<f32>,
567        soft_limit_fail_threshold: Option<f32>,
568        enable_adaptive_chunking: bool,
569        adaptive_min_chunk_size: Option<usize>,
570        adaptive_max_chunk_size: Option<usize>,
571    ) -> Result<crate::stream::JsonStream> {
572        async {
573            let startup_start = std::time::Instant::now();
574
575            use crate::json::validate_row_description;
576            use crate::stream::{extract_json_bytes, parse_json, AdaptiveChunking, ChunkingStrategy, JsonStream};
577            use serde_json::Value;
578            use tokio::sync::mpsc;
579
580            if self.state != ConnectionState::Idle {
581                return Err(Error::ConnectionBusy(format!(
582                    "connection in state: {}",
583                    self.state
584                )));
585            }
586
587            self.state.transition(ConnectionState::QueryInProgress)?;
588
589            let query_msg = FrontendMessage::Query(query.to_string());
590            self.send_message(&query_msg).await?;
591
592            self.state.transition(ConnectionState::ReadingResults)?;
593
594            // Read RowDescription, but handle other messages that may come first
595            // (e.g., ParameterStatus, BackendKeyData, ErrorResponse, NoticeResponse)
596            let row_desc;
597            loop {
598                let msg = self.receive_message().await?;
599
600                match msg {
601                    BackendMessage::ErrorResponse(err) => {
602                        // Query failed - consume ReadyForQuery and return error
603                        tracing::debug!("PostgreSQL error response: {}", err);
604                        loop {
605                            let msg = self.receive_message().await?;
606                            if matches!(msg, BackendMessage::ReadyForQuery { .. }) {
607                                break;
608                            }
609                        }
610                        return Err(Error::Sql(err.to_string()));
611                    }
612                    BackendMessage::BackendKeyData { process_id, secret_key: _ } => {
613                        // This provides the key needed for cancel requests - store it and continue
614                        tracing::debug!("PostgreSQL backend key data received: pid={}", process_id);
615                        // Note: We would store this if we need to support cancellation
616                        continue;
617                    }
618                    BackendMessage::ParameterStatus { .. } => {
619                        // Parameter status changes are informational - skip them
620                        tracing::debug!("PostgreSQL parameter status change received");
621                        continue;
622                    }
623                    BackendMessage::NoticeResponse(notice) => {
624                        // Notices are non-fatal warnings - skip them
625                        tracing::debug!("PostgreSQL notice: {}", notice);
626                        continue;
627                    }
628                    BackendMessage::RowDescription(_) => {
629                        row_desc = msg;
630                        break;
631                    }
632                    BackendMessage::ReadyForQuery { .. } => {
633                        // Received ReadyForQuery without RowDescription
634                        // This means the query didn't produce a result set
635                        return Err(Error::Protocol(
636                            "no result set received from query - \
637                             check that the entity name is correct and the table/view exists"
638                                .into(),
639                        ));
640                    }
641                    _ => {
642                        return Err(Error::Protocol(format!(
643                            "unexpected message type in query response: {:?}",
644                            msg
645                        )));
646                    }
647                }
648            }
649
650            validate_row_description(&row_desc)?;
651
652            // Record startup timing
653            let startup_duration = startup_start.elapsed().as_millis() as u64;
654            let entity = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
655            crate::metrics::histograms::query_startup_duration(&entity, startup_duration);
656
657            // Create channels
658            let (result_tx, result_rx) = mpsc::channel::<Result<Value>>(chunk_size);
659            let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
660
661            // Create stream instance first so we can clone its pause/resume signals
662            let entity_for_metrics = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
663            let entity_for_stream = entity_for_metrics.clone();  // Clone for stream
664
665            let stream = JsonStream::new(
666                result_rx,
667                cancel_tx,
668                entity_for_stream,
669                max_memory,
670                soft_limit_warn_threshold,
671                soft_limit_fail_threshold,
672            );
673
674            // Clone pause/resume signals for background task (only if pause/resume is initialized)
675            let state_lock = stream.clone_state();
676            let pause_signal = stream.clone_pause_signal();
677            let resume_signal = stream.clone_resume_signal();
678
679            // Clone atomic state for fast state checks in background task
680            let state_atomic = stream.clone_state_atomic();
681
682            // Clone pause timeout for background task
683            let pause_timeout = stream.pause_timeout();
684
685            // Spawn background task to read rows
686            let query_start = std::time::Instant::now();
687
688            tokio::spawn(async move {
689                let strategy = ChunkingStrategy::new(chunk_size);
690                let mut chunk = strategy.new_chunk();
691                let mut total_rows = 0u64;
692
693            // Initialize adaptive chunking if enabled
694            let _adaptive = if enable_adaptive_chunking {
695                let mut adp = AdaptiveChunking::new();
696
697                // Apply custom bounds if provided
698                if let Some(min) = adaptive_min_chunk_size {
699                    if let Some(max) = adaptive_max_chunk_size {
700                        adp = adp.with_bounds(min, max);
701                    }
702                }
703
704                Some(adp)
705            } else {
706                None
707            };
708            let _current_chunk_size = chunk_size;
709
710            loop {
711                // Check lightweight atomic state first (fast path)
712                // Only check atomic if pause/resume infrastructure is actually initialized
713                if state_lock.is_some() && state_atomic.load(std::sync::atomic::Ordering::Acquire) == 1 {
714                    // Paused state detected via atomic, now handle with Mutex
715                    if let (Some(ref state_lock), Some(ref _pause_signal), Some(ref resume_signal)) =
716                        (&state_lock, &pause_signal, &resume_signal)
717                    {
718                        let current_state = state_lock.lock().await;
719                        if *current_state == crate::stream::StreamState::Paused {
720                            tracing::debug!("stream paused, waiting for resume");
721                            drop(current_state); // Release lock before waiting
722
723                            // Wait with optional timeout
724                            if let Some(timeout) = pause_timeout {
725                                match tokio::time::timeout(timeout, resume_signal.notified()).await {
726                                    Ok(_) => {
727                                        tracing::debug!("stream resumed");
728                                    }
729                                    Err(_) => {
730                                        tracing::debug!("pause timeout expired, auto-resuming");
731                                        crate::metrics::counters::stream_pause_timeout_expired(&entity_for_metrics);
732                                    }
733                                }
734                            } else {
735                                // No timeout, wait indefinitely
736                                resume_signal.notified().await;
737                                tracing::debug!("stream resumed");
738                            }
739
740                            // Update state back to Running
741                            let mut state = state_lock.lock().await;
742                            *state = crate::stream::StreamState::Running;
743                        }
744                    }
745                }
746
747                tokio::select! {
748                    // Check for cancellation
749                    _ = cancel_rx.recv() => {
750                        tracing::debug!("query cancelled");
751                        crate::metrics::counters::query_completed("cancelled", &entity_for_metrics);
752                        break;
753                    }
754
755                    // Read next message
756                    msg_result = self.receive_message() => {
757                        match msg_result {
758                            Ok(msg) => match msg {
759                                BackendMessage::DataRow(_) => {
760                                    match extract_json_bytes(&msg) {
761                                        Ok(json_bytes) => {
762                                            chunk.push(json_bytes);
763
764                                            if strategy.is_full(&chunk) {
765                                                let chunk_start = std::time::Instant::now();
766                                                let rows = chunk.into_rows();
767                                                let chunk_size_rows = rows.len() as u64;
768
769                                                // Batch JSON parsing and sending to reduce lock contention
770                                                // Send 8 values per channel send instead of 1 (8x fewer locks)
771                                                const BATCH_SIZE: usize = 8;
772                                                let mut batch = Vec::with_capacity(BATCH_SIZE);
773                                                let mut send_error = false;
774
775                                                for row_bytes in rows {
776                                                    match parse_json(row_bytes) {
777                                                        Ok(value) => {
778                                                            total_rows += 1;
779                                                            batch.push(Ok(value));
780
781                                                            // Send batch when full
782                                                            if batch.len() == BATCH_SIZE {
783                                                                for item in batch.drain(..) {
784                                                                    if result_tx.send(item).await.is_err() {
785                                                                        crate::metrics::counters::query_completed("error", &entity_for_metrics);
786                                                                        send_error = true;
787                                                                        break;
788                                                                    }
789                                                                }
790                                                                if send_error {
791                                                                    break;
792                                                                }
793                                                            }
794                                                        }
795                                                        Err(e) => {
796                                                            crate::metrics::counters::json_parse_error(&entity_for_metrics);
797                                                            let _ = result_tx.send(Err(e)).await;
798                                                            crate::metrics::counters::query_completed("error", &entity_for_metrics);
799                                                            send_error = true;
800                                                            break;
801                                                        }
802                                                    }
803                                                }
804
805                                                // Send remaining batch items
806                                                if !send_error {
807                                                    for item in batch {
808                                                        if result_tx.send(item).await.is_err() {
809                                                            crate::metrics::counters::query_completed("error", &entity_for_metrics);
810                                                            break;
811                                                        }
812                                                    }
813                                                }
814
815                                                // Record chunk metrics (sampled, not per-chunk)
816                                                let chunk_duration = chunk_start.elapsed().as_millis() as u64;
817
818                                                // Only record metrics every 10 chunks to reduce overhead
819                                                let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
820                                                if chunk_idx % 10 == 0 {
821                                                    crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
822                                                    crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
823                                                }
824
825                                                // Adaptive chunking: disabled by default for better performance
826                                                // Enable only if explicitly requested via enable_adaptive_chunking parameter
827                                                // Note: adaptive adjustment adds ~0.5-1% overhead per chunk
828                                                // For fixed chunk sizes (default), skip this entirely
829
830                                                chunk = strategy.new_chunk();
831                                            }
832                                        }
833                                        Err(e) => {
834                                            crate::metrics::counters::json_parse_error(&entity_for_metrics);
835                                            let _ = result_tx.send(Err(e)).await;
836                                            crate::metrics::counters::query_completed("error", &entity_for_metrics);
837                                            break;
838                                        }
839                                    }
840                                }
841                                BackendMessage::CommandComplete(_) => {
842                                    // Send remaining chunk
843                                    if !chunk.is_empty() {
844                                        let chunk_start = std::time::Instant::now();
845                                        let rows = chunk.into_rows();
846                                        let chunk_size_rows = rows.len() as u64;
847
848                                        // Batch JSON parsing and sending to reduce lock contention
849                                        const BATCH_SIZE: usize = 8;
850                                        let mut batch = Vec::with_capacity(BATCH_SIZE);
851                                        let mut send_error = false;
852
853                                        for row_bytes in rows {
854                                            match parse_json(row_bytes) {
855                                                Ok(value) => {
856                                                    total_rows += 1;
857                                                    batch.push(Ok(value));
858
859                                                    // Send batch when full
860                                                    if batch.len() == BATCH_SIZE {
861                                                        for item in batch.drain(..) {
862                                                            if result_tx.send(item).await.is_err() {
863                                                                crate::metrics::counters::query_completed("error", &entity_for_metrics);
864                                                                send_error = true;
865                                                                break;
866                                                            }
867                                                        }
868                                                        if send_error {
869                                                            break;
870                                                        }
871                                                    }
872                                                }
873                                                Err(e) => {
874                                                    crate::metrics::counters::json_parse_error(&entity_for_metrics);
875                                                    let _ = result_tx.send(Err(e)).await;
876                                                    crate::metrics::counters::query_completed("error", &entity_for_metrics);
877                                                    send_error = true;
878                                                    break;
879                                                }
880                                            }
881                                        }
882
883                                        // Send remaining batch items
884                                        if !send_error {
885                                            for item in batch {
886                                                if result_tx.send(item).await.is_err() {
887                                                    crate::metrics::counters::query_completed("error", &entity_for_metrics);
888                                                    break;
889                                                }
890                                            }
891                                        }
892
893                                        // Record final chunk metrics (sampled)
894                                        let chunk_duration = chunk_start.elapsed().as_millis() as u64;
895                                        let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
896                                        if chunk_idx % 10 == 0 {
897                                            crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
898                                            crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
899                                        }
900                                        chunk = strategy.new_chunk();
901                                    }
902
903                                    // Record query completion metrics
904                                    let query_duration = query_start.elapsed().as_millis() as u64;
905                                    crate::metrics::counters::rows_processed(&entity_for_metrics, total_rows, "ok");
906                                    crate::metrics::histograms::query_total_duration(&entity_for_metrics, query_duration);
907                                    crate::metrics::counters::query_completed("success", &entity_for_metrics);
908                                }
909                                BackendMessage::ReadyForQuery { .. } => {
910                                    break;
911                                }
912                                BackendMessage::ErrorResponse(err) => {
913                                    crate::metrics::counters::query_error(&entity_for_metrics, "server_error");
914                                    crate::metrics::counters::query_completed("error", &entity_for_metrics);
915                                    let _ = result_tx.send(Err(Error::Sql(err.to_string()))).await;
916                                    break;
917                                }
918                                _ => {
919                                    crate::metrics::counters::query_error(&entity_for_metrics, "protocol_error");
920                                    crate::metrics::counters::query_completed("error", &entity_for_metrics);
921                                    let _ = result_tx.send(Err(Error::Protocol(
922                                        format!("unexpected message: {:?}", msg)
923                                    ))).await;
924                                    break;
925                                }
926                            },
927                            Err(e) => {
928                                crate::metrics::counters::query_error(&entity_for_metrics, "connection_error");
929                                crate::metrics::counters::query_completed("error", &entity_for_metrics);
930                                let _ = result_tx.send(Err(e)).await;
931                                break;
932                            }
933                        }
934                    }
935                }
936            }
937            });
938
939            Ok(stream)
940        }
941        .instrument(tracing::debug_span!(
942            "streaming_query",
943            query = %query,
944            chunk_size = %chunk_size
945        ))
946        .await
947    }
948}
949
/// Extract the entity name from a query, for use as a metrics label.
///
/// Queries are expected to look like `SELECT data FROM v_{entity} ...` (or
/// `tv_{entity}`). The relation that follows the first standalone `FROM`
/// keyword (matched case-insensitively as a whole token) is taken. Trailing
/// `;`, `,` and `)` are removed, then any schema qualifier (`schema.`) and
/// surrounding double quotes are stripped. Everything after the first `_` is
/// returned, lowercased.
///
/// For example, `SELECT data FROM v_user WHERE ...` gives `Some("user")` and
/// `SELECT data FROM public.tv_order_item;` gives `Some("order_item")`.
///
/// Returns `None` in any of these cases:
/// - the query has no `FROM` keyword;
/// - nothing follows `FROM`;
/// - the relation name contains no `_`;
/// - nothing follows the `_`.
fn extract_entity_from_query(query: &str) -> Option<String> {
    let mut tokens = query.split_whitespace();

    // Match FROM as a whole token so identifiers such as `from_date` or a
    // `'from'` string literal are not mistaken for the keyword. Splitting on
    // any whitespace also handles multi-line queries.
    tokens.find(|token| token.eq_ignore_ascii_case("from"))?;
    let relation = tokens.next()?;

    // `v_user;` / `v_user,` / `v_user)` -> `v_user`
    let relation = relation.trim_end_matches(|c: char| matches!(c, ';' | ',' | ')'));

    // `schema.v_user` / `"schema"."v_user"` -> `v_user`
    let table = relation
        .rsplit_once('.')
        .map_or(relation, |(_, name)| name)
        .trim_matches('"');

    // Everything after the prefix (`v_`, `tv_`, ...) is the entity, so
    // multi-word entities such as `order_item` stay intact.
    let (_, entity) = table.split_once('_')?;
    if entity.is_empty() {
        return None;
    }
    Some(entity.to_lowercase())
}
972
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_config() {
        let config = ConnectionConfig::new("testdb", "testuser")
            .password("testpass")
            .param("application_name", "fraiseql-wire");

        assert_eq!(config.database, "testdb");
        assert_eq!(config.user, "testuser");
        assert_eq!(
            config.password.as_ref().map(|p| p.as_str()),
            Some("testpass")
        );
        assert_eq!(
            config.params.get("application_name"),
            Some(&"fraiseql-wire".to_string())
        );
    }

    #[test]
    fn test_connection_config_builder_basic() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("mypass")
            .build();

        assert_eq!(config.database, "mydb");
        assert_eq!(config.user, "myuser");
        assert_eq!(config.password.as_ref().map(|p| p.as_str()), Some("mypass"));
        assert_eq!(config.connect_timeout, None);
        assert_eq!(config.statement_timeout, None);
        assert_eq!(config.keepalive_idle, None);
        assert_eq!(config.application_name, None);
    }

    #[test]
    fn test_connection_config_builder_with_timeouts() {
        let connect_timeout = Duration::from_secs(10);
        let statement_timeout = Duration::from_secs(30);
        let keepalive_idle = Duration::from_secs(300);

        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("mypass")
            .connect_timeout(connect_timeout)
            .statement_timeout(statement_timeout)
            .keepalive_idle(keepalive_idle)
            .build();

        assert_eq!(config.connect_timeout, Some(connect_timeout));
        assert_eq!(config.statement_timeout, Some(statement_timeout));
        assert_eq!(config.keepalive_idle, Some(keepalive_idle));
    }

    #[test]
    fn test_connection_config_builder_with_application_name() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .application_name("my_app")
            .extra_float_digits(2)
            .build();

        assert_eq!(config.application_name, Some("my_app".to_string()));
        assert_eq!(config.extra_float_digits, Some(2));
    }

    #[test]
    fn test_connection_config_builder_fluent() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("secret")
            .param("key1", "value1")
            .connect_timeout(Duration::from_secs(5))
            .statement_timeout(Duration::from_secs(60))
            .application_name("test_app")
            .build();

        assert_eq!(config.database, "mydb");
        assert_eq!(config.user, "myuser");
        assert_eq!(config.password.as_ref().map(|p| p.as_str()), Some("secret"));
        assert_eq!(config.params.get("key1"), Some(&"value1".to_string()));
        assert_eq!(config.connect_timeout, Some(Duration::from_secs(5)));
        assert_eq!(config.statement_timeout, Some(Duration::from_secs(60)));
        assert_eq!(config.application_name, Some("test_app".to_string()));
    }

    #[test]
    fn test_connection_config_defaults() {
        let config = ConnectionConfig::new("db", "user");

        assert!(config.connect_timeout.is_none());
        assert!(config.statement_timeout.is_none());
        assert!(config.keepalive_idle.is_none());
        assert!(config.application_name.is_none());
        assert!(config.extra_float_digits.is_none());
    }

    /// Compile-time check that `ConnectionConfig` can be moved into and shared
    /// between tasks on a multi-threaded executor.
    ///
    /// The body does nothing at runtime. It stops compiling if a field that
    /// is not `Send`/`Sync` is ever added to the config.
    ///
    /// NOTE(review): this does not verify that the futures returned by the
    /// connection's async methods are `Send`. That needs an assertion that
    /// names those futures, e.g. a helper `fn require_send<F: Send>(_: &F)`
    /// applied to a constructed (never awaited) call. TODO add once a test
    /// `Connection` can be built here.
    #[test]
    fn test_connection_config_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ConnectionConfig>();
    }
}