Skip to main content

sentinel_driver/connection/
client.rs

1use super::{
2    pipeline, startup, BackendMessage, BytesMut, CancelToken, Config, Connection, Duration, Error,
3    PgConnection, PipelineBatch, Result, StatementCache, ToSql, TransactionStatus,
4};
5use crate::config::{LoadBalanceHosts, TargetSessionAttrs};
6
7impl Connection {
8    /// Connect to PostgreSQL and perform the startup handshake.
9    ///
10    /// With multiple hosts configured, tries each host in order (or shuffled
11    /// if `load_balance_hosts=random`) until one succeeds and matches the
12    /// required `target_session_attrs`.
13    pub async fn connect(config: Config) -> Result<Self> {
14        let mut hosts: Vec<(String, u16)> = config.hosts().to_vec();
15
16        if hosts.is_empty() {
17            hosts.push(("localhost".to_string(), 5432));
18        }
19
20        if config.load_balance_hosts() == LoadBalanceHosts::Random {
21            use rand::seq::SliceRandom;
22            use rand::thread_rng;
23            hosts.shuffle(&mut thread_rng());
24        }
25
26        let mut last_error: Option<Error> = None;
27
28        for (host, port) in &hosts {
29            match Self::try_connect_host(&config, host, *port).await {
30                Ok(conn) => return Ok(conn),
31                Err(e) => {
32                    tracing::debug!(host = %host, port = %port, error = %e, "host failed");
33                    last_error = Some(e);
34                }
35            }
36        }
37
38        Err(last_error.unwrap_or_else(|| Error::AllHostsFailed("no hosts configured".to_string())))
39    }
40
41    /// Try connecting to a single host, performing startup and session attrs check.
42    async fn try_connect_host(config: &Config, host: &str, port: u16) -> Result<Self> {
43        let mut conn = PgConnection::connect_host(config, host, port).await?;
44        let result = startup::startup(&mut conn, config).await?;
45
46        // Check target_session_attrs after successful auth
47        if config.target_session_attrs() != TargetSessionAttrs::Any {
48            startup::check_session_attrs(&mut conn, config.target_session_attrs()).await?;
49        }
50
51        let query_timeout = config.statement_timeout();
52
53        Ok(Self {
54            conn,
55            config: config.clone(),
56            connected_host: host.to_string(),
57            connected_port: port,
58            process_id: result.process_id,
59            secret_key: result.secret_key,
60            transaction_status: result.transaction_status,
61            stmt_cache: StatementCache::new(),
62            query_timeout,
63            is_broken: false,
64            instrumentation: config
65                .instrumentation
66                .clone()
67                .unwrap_or_else(crate::instrumentation::noop),
68        })
69    }
70
71    /// Close the connection gracefully.
72    pub async fn close(self) -> Result<()> {
73        self.conn.close().await
74    }
75
76    /// Get a cancel token for this connection.
77    ///
78    /// The token can be cloned and sent to another task to cancel a
79    /// running query. See [`CancelToken`] for details.
80    pub fn cancel_token(&self) -> CancelToken {
81        CancelToken::new(
82            &self.connected_host,
83            self.connected_port,
84            self.process_id,
85            self.secret_key,
86        )
87    }
88
89    /// Returns `true` if the connection is using TLS.
90    /// Returns the configuration used for this connection.
91    pub fn config(&self) -> &Config {
92        &self.config
93    }
94
95    /// Returns the host this connection is connected to.
96    pub fn connected_host(&self) -> &str {
97        &self.connected_host
98    }
99
100    /// Returns the port this connection is connected to.
101    pub fn connected_port(&self) -> u16 {
102        self.connected_port
103    }
104
105    pub fn is_tls(&self) -> bool {
106        self.conn.is_tls()
107    }
108
109    /// Returns `true` if connected via Unix domain socket.
110    #[cfg(unix)]
111    pub fn is_unix(&self) -> bool {
112        self.conn.is_unix()
113    }
114
115    /// The server process ID for this connection.
116    pub fn process_id(&self) -> i32 {
117        self.process_id
118    }
119
120    /// Returns the configured query timeout, if any.
121    pub fn query_timeout(&self) -> Option<Duration> {
122        self.query_timeout
123    }
124
125    /// Returns `true` if the connection has been marked broken by a timeout.
126    ///
127    /// A broken connection should be discarded — the server state is
128    /// indeterminate after a cancelled query.
129    pub fn is_broken(&self) -> bool {
130        self.is_broken
131    }
132
133    /// Current transaction status.
134    pub fn transaction_status(&self) -> TransactionStatus {
135        self.transaction_status
136    }
137
138    /// Access the underlying PgConnection mutably.
139    pub(crate) fn pg_connection_mut(&mut self) -> &mut PgConnection {
140        &mut self.conn
141    }
142
143    // ── Internal ─────────────────────────────────────
144
145    pub(crate) async fn query_internal(
146        &mut self,
147        sql: &str,
148        params: &[&(dyn ToSql + Sync)],
149    ) -> Result<pipeline::QueryResult> {
150        self.instr().on_event(&crate::Event::ExecuteStart {
151            stmt: crate::StmtRef::Inline { sql },
152            param_count: params.len(),
153        });
154        let started = std::time::Instant::now();
155        let res = self.query_internal_inner(sql, params).await;
156        let duration = started.elapsed();
157        let (rows, outcome) = match &res {
158            Ok(pipeline::QueryResult::Rows(v)) => (v.len() as u64, crate::Outcome::Ok),
159            Ok(pipeline::QueryResult::Command(r)) => (r.rows_affected, crate::Outcome::Ok),
160            Err(e) => (0, crate::Outcome::Err(e)),
161        };
162        self.instr().on_event(&crate::Event::ExecuteFinish {
163            stmt: crate::StmtRef::Inline { sql },
164            rows,
165            duration,
166            outcome,
167        });
168        res
169    }
170
171    async fn query_internal_inner(
172        &mut self,
173        sql: &str,
174        params: &[&(dyn ToSql + Sync)],
175    ) -> Result<pipeline::QueryResult> {
176        // Encode parameters
177        let param_types: Vec<u32> = params.iter().map(|p| p.oid().0).collect();
178        let mut encoded_params: Vec<Option<Vec<u8>>> = Vec::with_capacity(params.len());
179
180        for param in params {
181            if param.is_null() {
182                encoded_params.push(None);
183            } else {
184                let mut buf = BytesMut::new();
185                param.to_sql(&mut buf)?;
186                encoded_params.push(Some(buf.to_vec()));
187            }
188        }
189
190        // Use pipeline for single query (same protocol, consistent code path)
191        let mut batch = PipelineBatch::new();
192        batch.add(sql.to_string(), param_types, encoded_params);
193
194        let mut results = batch.execute(&mut self.conn).await?;
195
196        results
197            .pop()
198            .ok_or_else(|| Error::protocol("pipeline returned no results"))
199    }
200
201    pub(crate) async fn drain_until_ready(&mut self) -> Result<()> {
202        loop {
203            if let BackendMessage::ReadyForQuery { transaction_status } = self.conn.recv().await? {
204                self.transaction_status = transaction_status;
205                return Ok(());
206            }
207        }
208    }
209
210    /// Install an `Instrumentation` impl on this connection.
211    /// Replaces any previous installation.
212    pub fn set_instrumentation(&mut self, instr: std::sync::Arc<dyn crate::Instrumentation>) {
213        self.instrumentation = instr;
214    }
215
216    /// Public accessor used by downstream macro helpers (e.g. sntl's
217    /// `__priv::emit_query_macro`). Returns the shared `Arc` so callers can
218    /// emit Sentinel-level events through the same trait.
219    pub fn instrumentation(&self) -> &std::sync::Arc<dyn crate::Instrumentation> {
220        &self.instrumentation
221    }
222
223    /// Crate-internal shorthand for wire sites.
224    pub(crate) fn instr(&self) -> &dyn crate::Instrumentation {
225        &*self.instrumentation
226    }
227}