Skip to main content

fraiseql_wire/connection/
conn.rs

1//! Core connection type
2
3use super::state::ConnectionState;
4use super::tls::SslMode;
5use super::transport::Transport;
6use crate::auth::scram::ChannelBinding;
7use crate::auth::ScramClient;
8use crate::protocol::{
9    decode_message, encode_message, AuthenticationMessage, BackendMessage, FrontendMessage,
10};
11use crate::{Error, Result};
12use bytes::{Buf, BytesMut};
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::time::Duration;
16use tracing::Instrument;
17
18// Global counter for chunk metrics sampling (1 per 10 chunks)
19// Used to reduce per-chunk metric recording overhead
20static CHUNK_COUNT: AtomicU64 = AtomicU64::new(0);
21
22/// Connection configuration
23///
24/// Stores connection parameters including database, credentials, and optional timeouts.
25/// Use `ConnectionConfig::builder()` for advanced configuration with timeouts and keepalive.
26#[derive(Debug, Clone)]
27pub struct ConnectionConfig {
28    /// Database name
29    pub database: String,
30    /// Username
31    pub user: String,
32    /// Password (optional)
33    pub password: Option<String>,
34    /// Additional connection parameters
35    pub params: HashMap<String, String>,
36    /// TCP connection timeout (default: 10 seconds)
37    pub connect_timeout: Option<Duration>,
38    /// Query statement timeout
39    pub statement_timeout: Option<Duration>,
40    /// TCP keepalive idle interval (default: 5 minutes)
41    pub keepalive_idle: Option<Duration>,
42    /// Application name for Postgres logs (default: "fraiseql-wire")
43    pub application_name: Option<String>,
44    /// Postgres extra_float_digits setting
45    pub extra_float_digits: Option<i32>,
46    /// SSL/TLS mode
47    pub sslmode: SslMode,
48}
49
50impl ConnectionConfig {
51    /// Create new configuration with defaults
52    ///
53    /// # Arguments
54    ///
55    /// * `database` - Database name
56    /// * `user` - Username
57    ///
58    /// # Defaults
59    ///
60    /// - `connect_timeout`: None
61    /// - `statement_timeout`: None
62    /// - `keepalive_idle`: None
63    /// - `application_name`: None
64    /// - `extra_float_digits`: None
65    ///
66    /// For configured timeouts and keepalive, use `builder()` instead.
67    pub fn new(database: impl Into<String>, user: impl Into<String>) -> Self {
68        Self {
69            database: database.into(),
70            user: user.into(),
71            password: None,
72            params: HashMap::new(),
73            connect_timeout: None,
74            statement_timeout: None,
75            keepalive_idle: None,
76            application_name: None,
77            extra_float_digits: None,
78            sslmode: SslMode::default(),
79        }
80    }
81
82    /// Create a builder for advanced configuration
83    ///
84    /// Use this to configure timeouts, keepalive, and application name.
85    ///
86    /// # Examples
87    ///
88    /// ```ignore
89    /// let config = ConnectionConfig::builder("mydb", "user")
90    ///     .connect_timeout(Duration::from_secs(10))
91    ///     .statement_timeout(Duration::from_secs(30))
92    ///     .build();
93    /// ```
94    pub fn builder(
95        database: impl Into<String>,
96        user: impl Into<String>,
97    ) -> ConnectionConfigBuilder {
98        ConnectionConfigBuilder {
99            database: database.into(),
100            user: user.into(),
101            password: None,
102            params: HashMap::new(),
103            connect_timeout: None,
104            statement_timeout: None,
105            keepalive_idle: None,
106            application_name: None,
107            extra_float_digits: None,
108            sslmode: SslMode::default(),
109        }
110    }
111
112    /// Set password
113    pub fn password(mut self, password: impl Into<String>) -> Self {
114        self.password = Some(password.into());
115        self
116    }
117
118    /// Add connection parameter
119    pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
120        self.params.insert(key.into(), value.into());
121        self
122    }
123}
124
125/// Builder for creating `ConnectionConfig` with advanced options
126///
127/// Provides a fluent API for configuring timeouts, keepalive, and application name.
128///
129/// # Examples
130///
131/// ```ignore
132/// let config = ConnectionConfig::builder("mydb", "user")
133///     .password("secret")
134///     .connect_timeout(Duration::from_secs(10))
135///     .statement_timeout(Duration::from_secs(30))
136///     .keepalive_idle(Duration::from_secs(300))
137///     .application_name("my_app")
138///     .build();
139/// ```
140#[derive(Debug, Clone)]
141pub struct ConnectionConfigBuilder {
142    database: String,
143    user: String,
144    password: Option<String>,
145    params: HashMap<String, String>,
146    connect_timeout: Option<Duration>,
147    statement_timeout: Option<Duration>,
148    keepalive_idle: Option<Duration>,
149    application_name: Option<String>,
150    extra_float_digits: Option<i32>,
151    sslmode: SslMode,
152}
153
154impl ConnectionConfigBuilder {
155    /// Set the password
156    pub fn password(mut self, password: impl Into<String>) -> Self {
157        self.password = Some(password.into());
158        self
159    }
160
161    /// Add a connection parameter
162    pub fn param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
163        self.params.insert(key.into(), value.into());
164        self
165    }
166
167    /// Set TCP connection timeout
168    ///
169    /// Default: None (no timeout)
170    ///
171    /// # Arguments
172    ///
173    /// * `duration` - Timeout duration for establishing TCP connection
174    pub fn connect_timeout(mut self, duration: Duration) -> Self {
175        self.connect_timeout = Some(duration);
176        self
177    }
178
179    /// Set statement (query) timeout
180    ///
181    /// Default: None (unlimited)
182    ///
183    /// # Arguments
184    ///
185    /// * `duration` - Timeout duration for query execution
186    pub fn statement_timeout(mut self, duration: Duration) -> Self {
187        self.statement_timeout = Some(duration);
188        self
189    }
190
191    /// Set TCP keepalive idle interval
192    ///
193    /// Default: None (OS default)
194    ///
195    /// # Arguments
196    ///
197    /// * `duration` - Idle duration before sending keepalive probes
198    pub fn keepalive_idle(mut self, duration: Duration) -> Self {
199        self.keepalive_idle = Some(duration);
200        self
201    }
202
203    /// Set application name for Postgres logs
204    ///
205    /// Default: None (Postgres will not set application_name)
206    ///
207    /// # Arguments
208    ///
209    /// * `name` - Application name to identify in Postgres logs
210    pub fn application_name(mut self, name: impl Into<String>) -> Self {
211        self.application_name = Some(name.into());
212        self
213    }
214
215    /// Set extra_float_digits for float precision
216    ///
217    /// Default: None (use Postgres default)
218    ///
219    /// # Arguments
220    ///
221    /// * `digits` - Number of extra digits (typically 0-2)
222    pub fn extra_float_digits(mut self, digits: i32) -> Self {
223        self.extra_float_digits = Some(digits);
224        self
225    }
226
227    /// Set SSL/TLS mode
228    pub fn sslmode(mut self, mode: SslMode) -> Self {
229        self.sslmode = mode;
230        self
231    }
232
233    /// Build the configuration
234    pub fn build(self) -> ConnectionConfig {
235        ConnectionConfig {
236            database: self.database,
237            user: self.user,
238            password: self.password,
239            params: self.params,
240            connect_timeout: self.connect_timeout,
241            statement_timeout: self.statement_timeout,
242            keepalive_idle: self.keepalive_idle,
243            application_name: self.application_name,
244            extra_float_digits: self.extra_float_digits,
245            sslmode: self.sslmode,
246        }
247    }
248}
249
250/// Postgres connection
251pub struct Connection {
252    transport: Option<Transport>,
253    state: ConnectionState,
254    read_buf: BytesMut,
255    process_id: Option<i32>,
256    secret_key: Option<i32>,
257}
258
259impl Connection {
260    /// Create connection from transport
261    pub fn new(transport: Transport) -> Self {
262        Self {
263            transport: Some(transport),
264            state: ConnectionState::Initial,
265            read_buf: BytesMut::with_capacity(8192),
266            process_id: None,
267            secret_key: None,
268        }
269    }
270
271    /// Get current connection state
272    pub fn state(&self) -> ConnectionState {
273        self.state
274    }
275
276    /// Negotiate TLS upgrade with the server via the SSLRequest protocol.
277    ///
278    /// Sends the 8-byte SSLRequest message and reads the server's single-byte response.
279    /// If the server responds with `S`, the transport is upgraded to TLS.
280    /// If the server responds with `N`, behavior depends on `sslmode`.
281    async fn negotiate_tls(
282        &mut self,
283        tls_config: &super::TlsConfig,
284        hostname: &str,
285        sslmode: SslMode,
286    ) -> Result<()> {
287        self.state.transition(ConnectionState::NegotiatingTls)?;
288
289        // Send SSLRequest
290        let ssl_request = FrontendMessage::SslRequest;
291        self.send_message(&ssl_request).await?;
292
293        // Read single-byte response (S = proceed with TLS, N = reject)
294        let transport = self
295            .transport
296            .as_mut()
297            .expect("transport taken during TLS upgrade");
298        let n = transport.read_buf(&mut self.read_buf).await?;
299        if n == 0 {
300            return Err(Error::ConnectionClosed);
301        }
302
303        let response = self.read_buf[0];
304        self.read_buf.advance(1);
305
306        match response {
307            b'S' => {
308                tracing::debug!("server accepted TLS, upgrading connection");
309                // Take transport out, upgrade to TLS, put it back
310                let transport = self.transport.take().expect("transport not available");
311                self.transport = Some(transport.upgrade_to_tls(tls_config, hostname).await?);
312                tracing::info!("TLS connection established");
313                Ok(())
314            }
315            b'N' => {
316                tracing::debug!("server rejected TLS");
317                Err(Error::Config(format!(
318                    "server does not support TLS (sslmode={})",
319                    sslmode
320                )))
321            }
322            other => Err(Error::Protocol(format!(
323                "unexpected SSLRequest response byte: 0x{:02X}",
324                other
325            ))),
326        }
327    }
328
329    /// Perform startup and authentication
330    pub async fn startup(
331        &mut self,
332        config: &ConnectionConfig,
333        tls_config: Option<&super::TlsConfig>,
334        hostname: Option<&str>,
335    ) -> Result<()> {
336        async {
337            // TLS negotiation (if requested)
338            if config.sslmode != SslMode::Disable {
339                let tls = tls_config.ok_or_else(|| {
340                    Error::Config(format!(
341                        "sslmode={} requires TlsConfig but none was provided",
342                        config.sslmode
343                    ))
344                })?;
345                let host = hostname
346                    .ok_or_else(|| Error::Config("TLS negotiation requires a hostname".into()))?;
347                self.negotiate_tls(tls, host, config.sslmode).await?;
348            }
349
350            self.state.transition(ConnectionState::AwaitingAuth)?;
351
352            // Build startup parameters
353            let mut params = vec![
354                ("user".to_string(), config.user.clone()),
355                ("database".to_string(), config.database.clone()),
356            ];
357
358            // Add configured application name if specified
359            if let Some(app_name) = &config.application_name {
360                params.push(("application_name".to_string(), app_name.clone()));
361            }
362
363            // Add statement timeout if specified (in milliseconds)
364            if let Some(timeout) = config.statement_timeout {
365                params.push((
366                    "statement_timeout".to_string(),
367                    timeout.as_millis().to_string(),
368                ));
369            }
370
371            // Add extra_float_digits if specified
372            if let Some(digits) = config.extra_float_digits {
373                params.push(("extra_float_digits".to_string(), digits.to_string()));
374            }
375
376            // Add user-provided parameters
377            for (k, v) in &config.params {
378                params.push((k.clone(), v.clone()));
379            }
380
381            // Send startup message
382            let startup = FrontendMessage::Startup {
383                version: crate::protocol::constants::PROTOCOL_VERSION,
384                params,
385            };
386            self.send_message(&startup).await?;
387
388            // Authentication loop
389            self.state.transition(ConnectionState::Authenticating)?;
390            self.authenticate(config).await?;
391
392            self.state.transition(ConnectionState::Idle)?;
393            tracing::info!("startup complete");
394            Ok(())
395        }
396        .instrument(tracing::info_span!(
397            "startup",
398            user = %config.user,
399            database = %config.database
400        ))
401        .await
402    }
403
404    /// Handle authentication
405    async fn authenticate(&mut self, config: &ConnectionConfig) -> Result<()> {
406        let auth_start = std::time::Instant::now();
407        let mut auth_mechanism = "unknown";
408
409        loop {
410            let msg = self.receive_message().await?;
411
412            match msg {
413                BackendMessage::Authentication(auth) => match auth {
414                    AuthenticationMessage::Ok => {
415                        tracing::debug!("authentication successful");
416                        crate::metrics::counters::auth_successful(auth_mechanism);
417                        crate::metrics::histograms::auth_duration(
418                            auth_mechanism,
419                            auth_start.elapsed().as_millis() as u64,
420                        );
421                        // Don't break here! Must continue reading until ReadyForQuery
422                    }
423                    AuthenticationMessage::CleartextPassword => {
424                        auth_mechanism = crate::metrics::labels::MECHANISM_CLEARTEXT;
425                        crate::metrics::counters::auth_attempted(auth_mechanism);
426
427                        let password = config
428                            .password
429                            .as_ref()
430                            .ok_or_else(|| Error::Authentication("password required".into()))?;
431                        let pwd_msg = FrontendMessage::Password(password.clone());
432                        self.send_message(&pwd_msg).await?;
433                    }
434                    AuthenticationMessage::Md5Password { .. } => {
435                        return Err(Error::Authentication(
436                            "MD5 authentication not supported. Use SCRAM-SHA-256 or cleartext password".into(),
437                        ));
438                    }
439                    AuthenticationMessage::Sasl { mechanisms } => {
440                        auth_mechanism = crate::metrics::labels::MECHANISM_SCRAM;
441                        crate::metrics::counters::auth_attempted(auth_mechanism);
442                        self.handle_sasl(&mechanisms, config).await?;
443                    }
444                    AuthenticationMessage::SaslContinue { .. } => {
445                        return Err(Error::Protocol(
446                            "unexpected SaslContinue outside of SASL flow".into(),
447                        ));
448                    }
449                    AuthenticationMessage::SaslFinal { .. } => {
450                        return Err(Error::Protocol(
451                            "unexpected SaslFinal outside of SASL flow".into(),
452                        ));
453                    }
454                },
455                BackendMessage::BackendKeyData {
456                    process_id,
457                    secret_key,
458                } => {
459                    self.process_id = Some(process_id);
460                    self.secret_key = Some(secret_key);
461                }
462                BackendMessage::ParameterStatus { name, value } => {
463                    tracing::debug!("parameter status: {} = {}", name, value);
464                }
465                BackendMessage::ReadyForQuery { status: _ } => {
466                    break;
467                }
468                BackendMessage::ErrorResponse(err) => {
469                    crate::metrics::counters::auth_failed(auth_mechanism, "server_error");
470                    return Err(Error::Authentication(err.to_string()));
471                }
472                _ => {
473                    return Err(Error::Protocol(format!(
474                        "unexpected message during auth: {:?}",
475                        msg
476                    )));
477                }
478            }
479        }
480
481        Ok(())
482    }
483
484    /// Handle SASL authentication (SCRAM-SHA-256)
485    async fn handle_sasl(
486        &mut self,
487        mechanisms: &[String],
488        config: &ConnectionConfig,
489    ) -> Result<()> {
490        // Determine channel binding and mechanism
491        let channel_binding_data = self
492            .transport
493            .as_ref()
494            .and_then(|t| t.channel_binding_data());
495
496        let (mechanism, channel_binding) = if mechanisms.contains(&"SCRAM-SHA-256-PLUS".to_string())
497        {
498            if let Some(cb_data) = channel_binding_data {
499                (
500                    "SCRAM-SHA-256-PLUS",
501                    ChannelBinding::TlsServerEndPoint(cb_data),
502                )
503            } else {
504                ("SCRAM-SHA-256", ChannelBinding::None)
505            }
506        } else if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
507            ("SCRAM-SHA-256", ChannelBinding::None)
508        } else {
509            return Err(Error::Authentication(format!(
510                "server does not support SCRAM-SHA-256. Available: {}",
511                mechanisms.join(", ")
512            )));
513        };
514
515        // Get password
516        let password = config.password.as_ref().ok_or_else(|| {
517            Error::Authentication("password required for SCRAM authentication".into())
518        })?;
519
520        // Create SCRAM client with channel binding support
521        let mut scram = ScramClient::with_channel_binding(
522            config.user.clone(),
523            password.clone(),
524            channel_binding,
525        );
526        tracing::debug!("initiating {} authentication", mechanism);
527
528        // Send SaslInitialResponse with client first message
529        let client_first = scram.client_first();
530        let msg = FrontendMessage::SaslInitialResponse {
531            mechanism: mechanism.to_string(),
532            data: client_first.into_bytes(),
533        };
534        self.send_message(&msg).await?;
535
536        // Receive SaslContinue with server first message
537        let server_first_msg = self.receive_message().await?;
538        let server_first_data = match server_first_msg {
539            BackendMessage::Authentication(AuthenticationMessage::SaslContinue { data }) => data,
540            BackendMessage::ErrorResponse(err) => {
541                return Err(Error::Authentication(format!("SASL server error: {}", err)));
542            }
543            _ => {
544                return Err(Error::Protocol(
545                    "expected SaslContinue message during SASL authentication".into(),
546                ));
547            }
548        };
549
550        let server_first = String::from_utf8(server_first_data).map_err(|e| {
551            Error::Authentication(format!("invalid UTF-8 in server first message: {}", e))
552        })?;
553
554        tracing::debug!("received SCRAM server first message");
555
556        // Generate client final message
557        let (client_final, scram_state) = scram
558            .client_final(&server_first)
559            .map_err(|e| Error::Authentication(format!("SCRAM error: {}", e)))?;
560
561        // Send SaslResponse with client final message
562        let msg = FrontendMessage::SaslResponse {
563            data: client_final.into_bytes(),
564        };
565        self.send_message(&msg).await?;
566
567        // Receive SaslFinal with server verification
568        let server_final_msg = self.receive_message().await?;
569        let server_final_data = match server_final_msg {
570            BackendMessage::Authentication(AuthenticationMessage::SaslFinal { data }) => data,
571            BackendMessage::ErrorResponse(err) => {
572                return Err(Error::Authentication(format!("SASL server error: {}", err)));
573            }
574            _ => {
575                return Err(Error::Protocol(
576                    "expected SaslFinal message during SASL authentication".into(),
577                ));
578            }
579        };
580
581        let server_final = String::from_utf8(server_final_data).map_err(|e| {
582            Error::Authentication(format!("invalid UTF-8 in server final message: {}", e))
583        })?;
584
585        // Verify server signature
586        scram
587            .verify_server_final(&server_final, &scram_state)
588            .map_err(|e| Error::Authentication(format!("SCRAM verification failed: {}", e)))?;
589
590        tracing::debug!("SCRAM-SHA-256 authentication successful");
591        Ok(())
592    }
593
594    /// Execute a simple query (returns all backend messages)
595    pub async fn simple_query(&mut self, query: &str) -> Result<Vec<BackendMessage>> {
596        if self.state != ConnectionState::Idle {
597            return Err(Error::ConnectionBusy(format!(
598                "connection in state: {}",
599                self.state
600            )));
601        }
602
603        self.state.transition(ConnectionState::QueryInProgress)?;
604
605        let query_msg = FrontendMessage::Query(query.to_string());
606        self.send_message(&query_msg).await?;
607
608        self.state.transition(ConnectionState::ReadingResults)?;
609
610        let mut messages = Vec::new();
611
612        loop {
613            let msg = self.receive_message().await?;
614            let is_ready = matches!(msg, BackendMessage::ReadyForQuery { .. });
615            messages.push(msg);
616
617            if is_ready {
618                break;
619            }
620        }
621
622        self.state.transition(ConnectionState::Idle)?;
623        Ok(messages)
624    }
625
626    /// Send a frontend message
627    async fn send_message(&mut self, msg: &FrontendMessage) -> Result<()> {
628        let buf = encode_message(msg)?;
629        let transport = self.transport.as_mut().expect("transport not available");
630        transport.write_all(&buf).await?;
631        transport.flush().await?;
632        Ok(())
633    }
634
635    /// Receive a backend message
636    async fn receive_message(&mut self) -> Result<BackendMessage> {
637        loop {
638            // Try to decode a message from buffer (without cloning!)
639            if let Ok((msg, consumed)) = decode_message(&mut self.read_buf) {
640                self.read_buf.advance(consumed);
641                return Ok(msg);
642            }
643
644            // Need more data
645            let transport = self.transport.as_mut().expect("transport not available");
646            let n = transport.read_buf(&mut self.read_buf).await?;
647            if n == 0 {
648                return Err(Error::ConnectionClosed);
649            }
650        }
651    }
652
653    /// Close the connection
654    pub async fn close(mut self) -> Result<()> {
655        self.state.transition(ConnectionState::Closed)?;
656        let _ = self.send_message(&FrontendMessage::Terminate).await;
657        let transport = self.transport.as_mut().expect("transport not available");
658        transport.shutdown().await?;
659        Ok(())
660    }
661
662    /// Execute a streaming query
663    ///
664    /// Note: This method consumes the connection. The stream maintains the connection
665    /// internally. Once the stream is exhausted or dropped, the connection is closed.
666    #[allow(clippy::too_many_arguments)]
667    pub async fn streaming_query(
668        mut self,
669        query: &str,
670        chunk_size: usize,
671        max_memory: Option<usize>,
672        soft_limit_warn_threshold: Option<f32>,
673        soft_limit_fail_threshold: Option<f32>,
674        enable_adaptive_chunking: bool,
675        adaptive_min_chunk_size: Option<usize>,
676        adaptive_max_chunk_size: Option<usize>,
677    ) -> Result<crate::stream::JsonStream> {
678        async {
679            let startup_start = std::time::Instant::now();
680
681            use crate::json::validate_row_description;
682            use crate::stream::{extract_json_bytes, parse_json, AdaptiveChunking, ChunkingStrategy, JsonStream};
683            use serde_json::Value;
684            use tokio::sync::mpsc;
685
686            if self.state != ConnectionState::Idle {
687                return Err(Error::ConnectionBusy(format!(
688                    "connection in state: {}",
689                    self.state
690                )));
691            }
692
693            self.state.transition(ConnectionState::QueryInProgress)?;
694
695            let query_msg = FrontendMessage::Query(query.to_string());
696            self.send_message(&query_msg).await?;
697
698            self.state.transition(ConnectionState::ReadingResults)?;
699
700            // Read RowDescription, but handle other messages that may come first
701            // (e.g., ParameterStatus, BackendKeyData, ErrorResponse, NoticeResponse)
702            let row_desc;
703            loop {
704                let msg = self.receive_message().await?;
705
706                match msg {
707                    BackendMessage::ErrorResponse(err) => {
708                        // Query failed - consume ReadyForQuery and return error
709                        tracing::debug!("PostgreSQL error response: {}", err);
710                        loop {
711                            let msg = self.receive_message().await?;
712                            if matches!(msg, BackendMessage::ReadyForQuery { .. }) {
713                                break;
714                            }
715                        }
716                        return Err(Error::Sql(err.to_string()));
717                    }
718                    BackendMessage::BackendKeyData { process_id, secret_key: _ } => {
719                        // This provides the key needed for cancel requests - store it and continue
720                        tracing::debug!("PostgreSQL backend key data received: pid={}", process_id);
721                        // Note: We would store this if we need to support cancellation
722                        continue;
723                    }
724                    BackendMessage::ParameterStatus { .. } => {
725                        // Parameter status changes are informational - skip them
726                        tracing::debug!("PostgreSQL parameter status change received");
727                        continue;
728                    }
729                    BackendMessage::NoticeResponse(notice) => {
730                        // Notices are non-fatal warnings - skip them
731                        tracing::debug!("PostgreSQL notice: {}", notice);
732                        continue;
733                    }
734                    BackendMessage::RowDescription(_) => {
735                        row_desc = msg;
736                        break;
737                    }
738                    BackendMessage::ReadyForQuery { .. } => {
739                        // Received ReadyForQuery without RowDescription
740                        // This means the query didn't produce a result set
741                        return Err(Error::Protocol(
742                            "no result set received from query - \
743                             check that the entity name is correct and the table/view exists"
744                                .into(),
745                        ));
746                    }
747                    _ => {
748                        return Err(Error::Protocol(format!(
749                            "unexpected message type in query response: {:?}",
750                            msg
751                        )));
752                    }
753                }
754            }
755
756            validate_row_description(&row_desc)?;
757
758            // Record startup timing
759            let startup_duration = startup_start.elapsed().as_millis() as u64;
760            let entity = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
761            crate::metrics::histograms::query_startup_duration(&entity, startup_duration);
762
763            // Create channels
764            let (result_tx, result_rx) = mpsc::channel::<Result<Value>>(chunk_size);
765            let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
766
767            // Create stream instance first so we can clone its pause/resume signals
768            let entity_for_metrics = extract_entity_from_query(query).unwrap_or_else(|| "unknown".to_string());
769            let entity_for_stream = entity_for_metrics.clone();  // Clone for stream
770
771            let stream = JsonStream::new(
772                result_rx,
773                cancel_tx,
774                entity_for_stream,
775                max_memory,
776                soft_limit_warn_threshold,
777                soft_limit_fail_threshold,
778            );
779
780            // Clone pause/resume signals for background task (only if pause/resume is initialized)
781            let state_lock = stream.clone_state();
782            let pause_signal = stream.clone_pause_signal();
783            let resume_signal = stream.clone_resume_signal();
784
785            // Clone atomic state for fast state checks in background task
786            let state_atomic = stream.clone_state_atomic();
787
788            // Clone pause timeout for background task
789            let pause_timeout = stream.pause_timeout();
790
791            // Spawn background task to read rows
792            let query_start = std::time::Instant::now();
793
794            tokio::spawn(async move {
795                let strategy = ChunkingStrategy::new(chunk_size);
796                let mut chunk = strategy.new_chunk();
797                let mut total_rows = 0u64;
798
799            // Initialize adaptive chunking if enabled
800            let _adaptive = if enable_adaptive_chunking {
801                let mut adp = AdaptiveChunking::new();
802
803                // Apply custom bounds if provided
804                if let Some(min) = adaptive_min_chunk_size {
805                    if let Some(max) = adaptive_max_chunk_size {
806                        adp = adp.with_bounds(min, max);
807                    }
808                }
809
810                Some(adp)
811            } else {
812                None
813            };
814            let _current_chunk_size = chunk_size;
815
816            loop {
817                // Check lightweight atomic state first (fast path)
818                // Only check atomic if pause/resume infrastructure is actually initialized
819                if state_lock.is_some() && state_atomic.load(std::sync::atomic::Ordering::Acquire) == 1 {
820                    // Paused state detected via atomic, now handle with Mutex
821                    if let (Some(ref state_lock), Some(ref _pause_signal), Some(ref resume_signal)) =
822                        (&state_lock, &pause_signal, &resume_signal)
823                    {
824                        let current_state = state_lock.lock().await;
825                        if *current_state == crate::stream::StreamState::Paused {
826                            tracing::debug!("stream paused, waiting for resume");
827                            drop(current_state); // Release lock before waiting
828
829                            // Wait with optional timeout
830                            if let Some(timeout) = pause_timeout {
831                                match tokio::time::timeout(timeout, resume_signal.notified()).await {
832                                    Ok(_) => {
833                                        tracing::debug!("stream resumed");
834                                    }
835                                    Err(_) => {
836                                        tracing::debug!("pause timeout expired, auto-resuming");
837                                        crate::metrics::counters::stream_pause_timeout_expired(&entity_for_metrics);
838                                    }
839                                }
840                            } else {
841                                // No timeout, wait indefinitely
842                                resume_signal.notified().await;
843                                tracing::debug!("stream resumed");
844                            }
845
846                            // Update state back to Running
847                            let mut state = state_lock.lock().await;
848                            *state = crate::stream::StreamState::Running;
849                        }
850                    }
851                }
852
853                tokio::select! {
854                    // Check for cancellation
855                    _ = cancel_rx.recv() => {
856                        tracing::debug!("query cancelled");
857                        crate::metrics::counters::query_completed("cancelled", &entity_for_metrics);
858                        break;
859                    }
860
861                    // Read next message
862                    msg_result = self.receive_message() => {
863                        match msg_result {
864                            Ok(msg) => match msg {
865                                BackendMessage::DataRow(_) => {
866                                    match extract_json_bytes(&msg) {
867                                        Ok(json_bytes) => {
868                                            chunk.push(json_bytes);
869
870                                            if strategy.is_full(&chunk) {
871                                                let chunk_start = std::time::Instant::now();
872                                                let rows = chunk.into_rows();
873                                                let chunk_size_rows = rows.len() as u64;
874
875                                                // Batch JSON parsing and sending to reduce lock contention
876                                                // Send 8 values per channel send instead of 1 (8x fewer locks)
877                                                const BATCH_SIZE: usize = 8;
878                                                let mut batch = Vec::with_capacity(BATCH_SIZE);
879                                                let mut send_error = false;
880
881                                                for row_bytes in rows {
882                                                    match parse_json(row_bytes) {
883                                                        Ok(value) => {
884                                                            total_rows += 1;
885                                                            batch.push(Ok(value));
886
887                                                            // Send batch when full
888                                                            if batch.len() == BATCH_SIZE {
889                                                                for item in batch.drain(..) {
890                                                                    if result_tx.send(item).await.is_err() {
891                                                                        crate::metrics::counters::query_completed("error", &entity_for_metrics);
892                                                                        send_error = true;
893                                                                        break;
894                                                                    }
895                                                                }
896                                                                if send_error {
897                                                                    break;
898                                                                }
899                                                            }
900                                                        }
901                                                        Err(e) => {
902                                                            crate::metrics::counters::json_parse_error(&entity_for_metrics);
903                                                            let _ = result_tx.send(Err(e)).await;
904                                                            crate::metrics::counters::query_completed("error", &entity_for_metrics);
905                                                            send_error = true;
906                                                            break;
907                                                        }
908                                                    }
909                                                }
910
911                                                // Send remaining batch items
912                                                if !send_error {
913                                                    for item in batch {
914                                                        if result_tx.send(item).await.is_err() {
915                                                            crate::metrics::counters::query_completed("error", &entity_for_metrics);
916                                                            break;
917                                                        }
918                                                    }
919                                                }
920
921                                                // Record chunk metrics (sampled, not per-chunk)
922                                                let chunk_duration = chunk_start.elapsed().as_millis() as u64;
923
924                                                // Only record metrics every 10 chunks to reduce overhead
925                                                let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
926                                                if chunk_idx % 10 == 0 {
927                                                    crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
928                                                    crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
929                                                }
930
931                                                // Adaptive chunking: disabled by default for better performance
932                                                // Enable only if explicitly requested via enable_adaptive_chunking parameter
933                                                // Note: adaptive adjustment adds ~0.5-1% overhead per chunk
934                                                // For fixed chunk sizes (default), skip this entirely
935
936                                                chunk = strategy.new_chunk();
937                                            }
938                                        }
939                                        Err(e) => {
940                                            crate::metrics::counters::json_parse_error(&entity_for_metrics);
941                                            let _ = result_tx.send(Err(e)).await;
942                                            crate::metrics::counters::query_completed("error", &entity_for_metrics);
943                                            break;
944                                        }
945                                    }
946                                }
947                                BackendMessage::CommandComplete(_) => {
948                                    // Send remaining chunk
949                                    if !chunk.is_empty() {
950                                        let chunk_start = std::time::Instant::now();
951                                        let rows = chunk.into_rows();
952                                        let chunk_size_rows = rows.len() as u64;
953
954                                        // Batch JSON parsing and sending to reduce lock contention
955                                        const BATCH_SIZE: usize = 8;
956                                        let mut batch = Vec::with_capacity(BATCH_SIZE);
957                                        let mut send_error = false;
958
959                                        for row_bytes in rows {
960                                            match parse_json(row_bytes) {
961                                                Ok(value) => {
962                                                    total_rows += 1;
963                                                    batch.push(Ok(value));
964
965                                                    // Send batch when full
966                                                    if batch.len() == BATCH_SIZE {
967                                                        for item in batch.drain(..) {
968                                                            if result_tx.send(item).await.is_err() {
969                                                                crate::metrics::counters::query_completed("error", &entity_for_metrics);
970                                                                send_error = true;
971                                                                break;
972                                                            }
973                                                        }
974                                                        if send_error {
975                                                            break;
976                                                        }
977                                                    }
978                                                }
979                                                Err(e) => {
980                                                    crate::metrics::counters::json_parse_error(&entity_for_metrics);
981                                                    let _ = result_tx.send(Err(e)).await;
982                                                    crate::metrics::counters::query_completed("error", &entity_for_metrics);
983                                                    send_error = true;
984                                                    break;
985                                                }
986                                            }
987                                        }
988
989                                        // Send remaining batch items
990                                        if !send_error {
991                                            for item in batch {
992                                                if result_tx.send(item).await.is_err() {
993                                                    crate::metrics::counters::query_completed("error", &entity_for_metrics);
994                                                    break;
995                                                }
996                                            }
997                                        }
998
999                                        // Record final chunk metrics (sampled)
1000                                        let chunk_duration = chunk_start.elapsed().as_millis() as u64;
1001                                        let chunk_idx = CHUNK_COUNT.fetch_add(1, Ordering::Relaxed);
1002                                        if chunk_idx % 10 == 0 {
1003                                            crate::metrics::histograms::chunk_processing_duration(&entity_for_metrics, chunk_duration);
1004                                            crate::metrics::histograms::chunk_size(&entity_for_metrics, chunk_size_rows);
1005                                        }
1006                                        chunk = strategy.new_chunk();
1007                                    }
1008
1009                                    // Record query completion metrics
1010                                    let query_duration = query_start.elapsed().as_millis() as u64;
1011                                    crate::metrics::counters::rows_processed(&entity_for_metrics, total_rows, "ok");
1012                                    crate::metrics::histograms::query_total_duration(&entity_for_metrics, query_duration);
1013                                    crate::metrics::counters::query_completed("success", &entity_for_metrics);
1014                                }
1015                                BackendMessage::ReadyForQuery { .. } => {
1016                                    break;
1017                                }
1018                                BackendMessage::ErrorResponse(err) => {
1019                                    crate::metrics::counters::query_error(&entity_for_metrics, "server_error");
1020                                    crate::metrics::counters::query_completed("error", &entity_for_metrics);
1021                                    let _ = result_tx.send(Err(Error::Sql(err.to_string()))).await;
1022                                    break;
1023                                }
1024                                _ => {
1025                                    crate::metrics::counters::query_error(&entity_for_metrics, "protocol_error");
1026                                    crate::metrics::counters::query_completed("error", &entity_for_metrics);
1027                                    let _ = result_tx.send(Err(Error::Protocol(
1028                                        format!("unexpected message: {:?}", msg)
1029                                    ))).await;
1030                                    break;
1031                                }
1032                            },
1033                            Err(e) => {
1034                                crate::metrics::counters::query_error(&entity_for_metrics, "connection_error");
1035                                crate::metrics::counters::query_completed("error", &entity_for_metrics);
1036                                let _ = result_tx.send(Err(e)).await;
1037                                break;
1038                            }
1039                        }
1040                    }
1041                }
1042            }
1043            });
1044
1045            Ok(stream)
1046        }
1047        .instrument(tracing::debug_span!(
1048            "streaming_query",
1049            query = %query,
1050            chunk_size = %chunk_size
1051        ))
1052        .await
1053    }
1054}
1055
/// Extract the entity name from a query, for use as a metrics label.
///
/// Generated queries have the shape `SELECT data FROM v_{entity} ...` (or
/// `tv_{entity}`). The query is lowercased. For every whole-word `from`
/// keyword, the identifier that follows it is taken as a candidate table name:
/// - the identifier ends at whitespace, `;`, `,`, `(` or `)`;
/// - any `schema.` qualifier is dropped;
/// - surrounding double quotes are removed.
///
/// A candidate starting with `v_` or `tv_` is used immediately. Otherwise the
/// first candidate containing `_` is used. This lets a query such as
/// `SELECT extract(year FROM created_at), data FROM v_user` resolve to `user`
/// rather than to the column in the select list.
///
/// The entity is the text after the last `_` of the chosen table name, so
/// `v_user` and `public.tv_user` both yield `"user"`.
///
/// Returns `None` when no candidate yields a non-empty entity.
fn extract_entity_from_query(query: &str) -> Option<String> {
    let query_lower = query.to_lowercase();

    let mut fallback = None;
    for table in table_names_after_from(&query_lower) {
        if table.starts_with("v_") || table.starts_with("tv_") {
            return entity_from_table(table);
        }
        if fallback.is_none() && table.contains('_') {
            fallback = Some(table);
        }
    }
    fallback.and_then(entity_from_table)
}

/// Yield the table identifier following each whole-word `from` keyword in an
/// already-lowercased query, with any schema qualifier and double quotes removed.
fn table_names_after_from(query_lower: &str) -> impl Iterator<Item = &str> {
    let is_ident = |c: char| c.is_alphanumeric() || c == '_';
    query_lower
        .match_indices("from")
        .filter_map(move |(pos, keyword)| {
            // Reject "from" embedded in a longer identifier (e.g. `data_from`, `fromage`).
            let before = query_lower[..pos].chars().next_back();
            let rest = &query_lower[pos + keyword.len()..];
            if matches!(before, Some(c) if is_ident(c))
                || matches!(rest.chars().next(), Some(c) if is_ident(c))
            {
                return None;
            }
            // The identifier ends at the FIRST delimiter of any kind, not just at a space.
            let token = rest
                .trim_start()
                .split(|c: char| c.is_whitespace() || matches!(c, ';' | ',' | '(' | ')'))
                .next()
                .unwrap_or("");
            // Drop a `schema.` qualifier and identifier quoting.
            let table = token.rsplit('.').next().unwrap_or(token).trim_matches('"');
            if table.is_empty() {
                None
            } else {
                Some(table)
            }
        })
}

/// Return the text after the last `_` in `table`, or `None` if there is no `_`
/// or nothing follows it.
fn entity_from_table(table: &str) -> Option<String> {
    let (_, entity) = table.rsplit_once('_')?;
    if entity.is_empty() {
        None
    } else {
        Some(entity.to_string())
    }
}
1078
#[cfg(test)]
mod tests {
    //! Unit tests for `ConnectionConfig` construction and builder defaults.

    use super::*;

    #[test]
    fn test_connection_config() {
        let config = ConnectionConfig::new("testdb", "testuser")
            .password("testpass")
            .param("application_name", "fraiseql-wire");

        assert_eq!(config.database, "testdb");
        assert_eq!(config.user, "testuser");
        assert_eq!(config.password, Some("testpass".to_string()));
        assert_eq!(
            config.params.get("application_name"),
            Some(&"fraiseql-wire".to_string())
        );
    }

    #[test]
    fn test_connection_config_builder_basic() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("mypass")
            .build();

        assert_eq!(config.database, "mydb");
        assert_eq!(config.user, "myuser");
        assert_eq!(config.password, Some("mypass".to_string()));
        assert_eq!(config.connect_timeout, None);
        assert_eq!(config.statement_timeout, None);
        assert_eq!(config.keepalive_idle, None);
        assert_eq!(config.application_name, None);
    }

    #[test]
    fn test_connection_config_builder_with_timeouts() {
        let connect_timeout = Duration::from_secs(10);
        let statement_timeout = Duration::from_secs(30);
        let keepalive_idle = Duration::from_secs(300);

        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("mypass")
            .connect_timeout(connect_timeout)
            .statement_timeout(statement_timeout)
            .keepalive_idle(keepalive_idle)
            .build();

        assert_eq!(config.connect_timeout, Some(connect_timeout));
        assert_eq!(config.statement_timeout, Some(statement_timeout));
        assert_eq!(config.keepalive_idle, Some(keepalive_idle));
    }

    #[test]
    fn test_connection_config_builder_with_application_name() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .application_name("my_app")
            .extra_float_digits(2)
            .build();

        assert_eq!(config.application_name, Some("my_app".to_string()));
        assert_eq!(config.extra_float_digits, Some(2));
    }

    #[test]
    fn test_connection_config_builder_fluent() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .password("secret")
            .param("key1", "value1")
            .connect_timeout(Duration::from_secs(5))
            .statement_timeout(Duration::from_secs(60))
            .application_name("test_app")
            .build();

        assert_eq!(config.database, "mydb");
        assert_eq!(config.user, "myuser");
        assert_eq!(config.password, Some("secret".to_string()));
        assert_eq!(config.params.get("key1"), Some(&"value1".to_string()));
        assert_eq!(config.connect_timeout, Some(Duration::from_secs(5)));
        assert_eq!(config.statement_timeout, Some(Duration::from_secs(60)));
        assert_eq!(config.application_name, Some("test_app".to_string()));
    }

    #[test]
    fn test_connection_config_defaults() {
        let config = ConnectionConfig::new("db", "user");

        assert!(config.connect_timeout.is_none());
        assert!(config.statement_timeout.is_none());
        assert!(config.keepalive_idle.is_none());
        assert!(config.application_name.is_none());
        assert!(config.extra_float_digits.is_none());
        assert_eq!(config.sslmode, super::SslMode::Disable);
    }

    #[test]
    fn test_connection_config_builder_with_sslmode() {
        let config = ConnectionConfig::builder("mydb", "myuser")
            .sslmode(super::SslMode::VerifyFull)
            .build();

        assert_eq!(config.sslmode, super::SslMode::VerifyFull);
    }

    // Compile-time `Send` assertion helper. The closure is never called; only
    // type-checking matters.
    //
    // NOTE(review): as written this only proves that
    // `Pin<Box<dyn Future<Output = ()> + Send>>` is `Send`, which holds by
    // construction of that type. It does not inspect any future produced by
    // this module, so it cannot catch a non-`Send` async fn here. To guard a
    // real async method, pass the type of the future it returns to
    // `require_send` (e.g. through a generic helper taking `impl Future`).
    #[allow(dead_code)]
    const _SEND_SAFETY_CHECK: fn() = || {
        fn require_send<T: Send>() {}
        require_send::<std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>>();
    };
}