Skip to main content

qail_pg/driver/connection/
startup.rs

//! Startup handshake — authentication, parameter negotiation, and prepared-statement management.

3use super::helpers::{generate_gss_token, md5_password_message, select_scram_mechanism};
4use super::types::{GSS_SESSION_COUNTER, PgConnection, StartupAuthFlow};
5use crate::driver::stream::PgStream;
6use crate::driver::{
7    AuthSettings, EnterpriseAuthMechanism, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
8};
9use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
10use sha2::{Digest, Sha256};
11use std::sync::atomic::Ordering;
12
13impl PgConnection {
14    /// Handle startup sequence (auth + params).
15    pub(super) async fn handle_startup(
16        &mut self,
17        user: &str,
18        password: Option<&str>,
19        auth_settings: AuthSettings,
20        gss_token_provider: Option<GssTokenProvider>,
21        gss_token_provider_ex: Option<GssTokenProviderEx>,
22    ) -> PgResult<()> {
23        let mut scram_client: Option<ScramClient> = None;
24        let mut startup_auth_flow: Option<StartupAuthFlow> = None;
25        let mut saw_auth_ok = false;
26        let gss_session_id = GSS_SESSION_COUNTER.fetch_add(1, Ordering::Relaxed);
27        let mut gss_roundtrips: u32 = 0;
28        const MAX_GSS_ROUNDTRIPS: u32 = 32;
29
30        loop {
31            let msg = self.recv().await?;
32            if saw_auth_ok
33                && matches!(
34                    &msg,
35                    BackendMessage::AuthenticationOk
36                        | BackendMessage::AuthenticationKerberosV5
37                        | BackendMessage::AuthenticationGSS
38                        | BackendMessage::AuthenticationGSSContinue(_)
39                        | BackendMessage::AuthenticationSSPI
40                        | BackendMessage::AuthenticationCleartextPassword
41                        | BackendMessage::AuthenticationMD5Password(_)
42                        | BackendMessage::AuthenticationSASL(_)
43                        | BackendMessage::AuthenticationSASLContinue(_)
44                        | BackendMessage::AuthenticationSASLFinal(_)
45                )
46            {
47                return Err(PgError::Protocol(
48                    "Received authentication challenge after AuthenticationOk".to_string(),
49                ));
50            }
51            match msg {
52                BackendMessage::AuthenticationOk => {
53                    if let Some(StartupAuthFlow::Scram {
54                        server_final_seen: false,
55                    }) = startup_auth_flow
56                    {
57                        return Err(PgError::Protocol(
58                            "Received AuthenticationOk before AuthenticationSASLFinal".to_string(),
59                        ));
60                    }
61                    saw_auth_ok = true;
62                }
63                BackendMessage::AuthenticationKerberosV5 => {
64                    if let Some(flow) = startup_auth_flow {
65                        return Err(PgError::Protocol(format!(
66                            "Received AuthenticationKerberosV5 while {} authentication is in progress",
67                            flow.label()
68                        )));
69                    }
70                    startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
71                        mechanism: EnterpriseAuthMechanism::KerberosV5,
72                    });
73
74                    if !auth_settings.allow_kerberos_v5 {
75                        return Err(PgError::Auth(
76                            "Server requested Kerberos V5 authentication, but Kerberos V5 is disabled by AuthSettings".to_string(),
77                        ));
78                    }
79
80                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
81                        return Err(PgError::Auth(
82                            "Kerberos V5 authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
83                        ));
84                    }
85
86                    let token = generate_gss_token(
87                        gss_session_id,
88                        EnterpriseAuthMechanism::KerberosV5,
89                        None,
90                        gss_token_provider,
91                        gss_token_provider_ex.as_ref(),
92                    )
93                    .map_err(|e| {
94                        PgError::Auth(format!("Kerberos V5 token generation failed: {}", e))
95                    })?;
96
97                    self.send(FrontendMessage::GSSResponse(token)).await?;
98                }
99                BackendMessage::AuthenticationGSS => {
100                    if let Some(flow) = startup_auth_flow {
101                        return Err(PgError::Protocol(format!(
102                            "Received AuthenticationGSS while {} authentication is in progress",
103                            flow.label()
104                        )));
105                    }
106                    startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
107                        mechanism: EnterpriseAuthMechanism::GssApi,
108                    });
109
110                    if !auth_settings.allow_gssapi {
111                        return Err(PgError::Auth(
112                            "Server requested GSSAPI authentication, but GSSAPI is disabled by AuthSettings".to_string(),
113                        ));
114                    }
115
116                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
117                        return Err(PgError::Auth(
118                            "GSSAPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
119                        ));
120                    }
121
122                    let token = generate_gss_token(
123                        gss_session_id,
124                        EnterpriseAuthMechanism::GssApi,
125                        None,
126                        gss_token_provider,
127                        gss_token_provider_ex.as_ref(),
128                    )
129                    .map_err(|e| {
130                        PgError::Auth(format!("GSSAPI initial token generation failed: {}", e))
131                    })?;
132
133                    self.send(FrontendMessage::GSSResponse(token)).await?;
134                }
135                BackendMessage::AuthenticationSSPI => {
136                    if let Some(flow) = startup_auth_flow {
137                        return Err(PgError::Protocol(format!(
138                            "Received AuthenticationSSPI while {} authentication is in progress",
139                            flow.label()
140                        )));
141                    }
142                    startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
143                        mechanism: EnterpriseAuthMechanism::Sspi,
144                    });
145
146                    if !auth_settings.allow_sspi {
147                        return Err(PgError::Auth(
148                            "Server requested SSPI authentication, but SSPI is disabled by AuthSettings".to_string(),
149                        ));
150                    }
151
152                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
153                        return Err(PgError::Auth(
154                            "SSPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
155                        ));
156                    }
157
158                    let token = generate_gss_token(
159                        gss_session_id,
160                        EnterpriseAuthMechanism::Sspi,
161                        None,
162                        gss_token_provider,
163                        gss_token_provider_ex.as_ref(),
164                    )
165                    .map_err(|e| {
166                        PgError::Auth(format!("SSPI initial token generation failed: {}", e))
167                    })?;
168
169                    self.send(FrontendMessage::GSSResponse(token)).await?;
170                }
171                BackendMessage::AuthenticationGSSContinue(server_token) => {
172                    gss_roundtrips += 1;
173                    if gss_roundtrips > MAX_GSS_ROUNDTRIPS {
174                        return Err(PgError::Auth(format!(
175                            "GSS handshake exceeded {} roundtrips — aborting",
176                            MAX_GSS_ROUNDTRIPS
177                        )));
178                    }
179
180                    let mechanism = match startup_auth_flow {
181                        Some(StartupAuthFlow::EnterpriseGss { mechanism }) => mechanism,
182                        Some(flow) => {
183                            return Err(PgError::Protocol(format!(
184                                "Received AuthenticationGSSContinue while {} authentication is in progress",
185                                flow.label()
186                            )));
187                        }
188                        None => {
189                            return Err(PgError::Auth(
190                                "Received GSSContinue without AuthenticationGSS/SSPI/KerberosV5 init"
191                                    .to_string(),
192                            ));
193                        }
194                    };
195
196                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
197                        return Err(PgError::Auth(
198                            "Received GSSContinue but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
199                        ));
200                    }
201
202                    let token = generate_gss_token(
203                        gss_session_id,
204                        mechanism,
205                        Some(&server_token),
206                        gss_token_provider,
207                        gss_token_provider_ex.as_ref(),
208                    )
209                    .map_err(|e| {
210                        PgError::Auth(format!("GSS continue token generation failed: {}", e))
211                    })?;
212
213                    // Only send the response if there is actually a token to
214                    // send.  When gss_init_sec_context returns GSS_S_COMPLETE
215                    // on the final round, the token may be empty.  Sending an
216                    // empty GSSResponse ('p') after the server already
217                    // considers auth complete trips the "invalid frontend
218                    // message type 112" FATAL in PostgreSQL.
219                    if !token.is_empty() {
220                        self.send(FrontendMessage::GSSResponse(token)).await?;
221                    }
222                }
223                BackendMessage::AuthenticationCleartextPassword => {
224                    if let Some(flow) = startup_auth_flow {
225                        return Err(PgError::Protocol(format!(
226                            "Received AuthenticationCleartextPassword while {} authentication is in progress",
227                            flow.label()
228                        )));
229                    }
230                    startup_auth_flow = Some(StartupAuthFlow::CleartextPassword);
231
232                    if !auth_settings.allow_cleartext_password {
233                        return Err(PgError::Auth(
234                            "Server requested cleartext authentication, but cleartext is disabled by AuthSettings"
235                                .to_string(),
236                        ));
237                    }
238                    let password = password.ok_or_else(|| {
239                        PgError::Auth("Password required for cleartext authentication".to_string())
240                    })?;
241                    self.send(FrontendMessage::PasswordMessage(password.to_string()))
242                        .await?;
243                }
244                BackendMessage::AuthenticationMD5Password(salt) => {
245                    if let Some(flow) = startup_auth_flow {
246                        return Err(PgError::Protocol(format!(
247                            "Received AuthenticationMD5Password while {} authentication is in progress",
248                            flow.label()
249                        )));
250                    }
251                    startup_auth_flow = Some(StartupAuthFlow::Md5Password);
252
253                    if !auth_settings.allow_md5_password {
254                        return Err(PgError::Auth(
255                            "Server requested MD5 authentication, but MD5 is disabled by AuthSettings"
256                                .to_string(),
257                        ));
258                    }
259                    let password = password.ok_or_else(|| {
260                        PgError::Auth("Password required for MD5 authentication".to_string())
261                    })?;
262                    let md5_password = md5_password_message(user, password, salt);
263                    self.send(FrontendMessage::PasswordMessage(md5_password))
264                        .await?;
265                }
266                BackendMessage::AuthenticationSASL(mechanisms) => {
267                    if let Some(flow) = startup_auth_flow {
268                        return Err(PgError::Protocol(format!(
269                            "Received AuthenticationSASL while {} authentication is in progress",
270                            flow.label()
271                        )));
272                    }
273                    startup_auth_flow = Some(StartupAuthFlow::Scram {
274                        server_final_seen: false,
275                    });
276
277                    if !auth_settings.allow_scram_sha_256 {
278                        return Err(PgError::Auth(
279                            "Server requested SCRAM authentication, but SCRAM is disabled by AuthSettings"
280                                .to_string(),
281                        ));
282                    }
283                    let password = password.ok_or_else(|| {
284                        PgError::Auth("Password required for SCRAM authentication".to_string())
285                    })?;
286
287                    let tls_binding = self.tls_server_end_point_channel_binding();
288                    let (mechanism, channel_binding_data) = select_scram_mechanism(
289                        &mechanisms,
290                        tls_binding,
291                        auth_settings.channel_binding,
292                    )
293                    .map_err(PgError::Auth)?;
294
295                    let client = if let Some(binding_data) = channel_binding_data {
296                        ScramClient::new_with_tls_server_end_point(user, password, binding_data)
297                    } else {
298                        ScramClient::new(user, password)
299                    };
300                    let first_message = client.client_first_message();
301
302                    self.send(FrontendMessage::SASLInitialResponse {
303                        mechanism,
304                        data: first_message,
305                    })
306                    .await?;
307
308                    scram_client = Some(client);
309                }
310                BackendMessage::AuthenticationSASLContinue(server_data) => {
311                    match startup_auth_flow {
312                        Some(StartupAuthFlow::Scram {
313                            server_final_seen: false,
314                        }) => {}
315                        Some(StartupAuthFlow::Scram {
316                            server_final_seen: true,
317                        }) => {
318                            return Err(PgError::Protocol(
319                                "Received AuthenticationSASLContinue after AuthenticationSASLFinal"
320                                    .to_string(),
321                            ));
322                        }
323                        Some(flow) => {
324                            return Err(PgError::Protocol(format!(
325                                "Received AuthenticationSASLContinue while {} authentication is in progress",
326                                flow.label()
327                            )));
328                        }
329                        None => {
330                            return Err(PgError::Auth(
331                                "Received SASL Continue without SASL init".to_string(),
332                            ));
333                        }
334                    }
335
336                    let client = scram_client.as_mut().ok_or_else(|| {
337                        PgError::Auth("Received SASL Continue without SASL init".to_string())
338                    })?;
339
340                    let final_message = client
341                        .process_server_first(&server_data)
342                        .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
343
344                    self.send(FrontendMessage::SASLResponse(final_message))
345                        .await?;
346                }
347                BackendMessage::AuthenticationSASLFinal(server_signature) => {
348                    match startup_auth_flow {
349                        Some(StartupAuthFlow::Scram {
350                            server_final_seen: false,
351                        }) => {
352                            startup_auth_flow = Some(StartupAuthFlow::Scram {
353                                server_final_seen: true,
354                            });
355                        }
356                        Some(StartupAuthFlow::Scram {
357                            server_final_seen: true,
358                        }) => {
359                            return Err(PgError::Protocol(
360                                "Received duplicate AuthenticationSASLFinal".to_string(),
361                            ));
362                        }
363                        Some(flow) => {
364                            return Err(PgError::Protocol(format!(
365                                "Received AuthenticationSASLFinal while {} authentication is in progress",
366                                flow.label()
367                            )));
368                        }
369                        None => {
370                            return Err(PgError::Auth(
371                                "Received SASL Final without SASL init".to_string(),
372                            ));
373                        }
374                    }
375
376                    let client = scram_client.as_ref().ok_or_else(|| {
377                        PgError::Auth("Received SASL Final without SASL init".to_string())
378                    })?;
379                    client
380                        .verify_server_final(&server_signature)
381                        .map_err(|e| PgError::Auth(format!("Server verification failed: {}", e)))?;
382                }
383                BackendMessage::ParameterStatus { .. } => {
384                    if !saw_auth_ok {
385                        return Err(PgError::Protocol(
386                            "Received ParameterStatus before AuthenticationOk".to_string(),
387                        ));
388                    }
389                }
390                BackendMessage::BackendKeyData {
391                    process_id,
392                    secret_key,
393                } => {
394                    if !saw_auth_ok {
395                        return Err(PgError::Protocol(
396                            "Received BackendKeyData before AuthenticationOk".to_string(),
397                        ));
398                    }
399                    self.process_id = process_id;
400                    self.secret_key = secret_key;
401                }
402                BackendMessage::ReadyForQuery(TransactionStatus::Idle)
403                | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
404                | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
405                    if !saw_auth_ok {
406                        return Err(PgError::Protocol(
407                            "Startup completed without AuthenticationOk".to_string(),
408                        ));
409                    }
410                    return Ok(());
411                }
412                BackendMessage::ErrorResponse(err) => {
413                    return Err(PgError::Connection(err.message));
414                }
415                BackendMessage::NoticeResponse(_) => {}
416                _ => {
417                    return Err(PgError::Protocol(
418                        "Unexpected backend message during startup".to_string(),
419                    ));
420                }
421            }
422        }
423    }
424
425    /// Build SCRAM `tls-server-end-point` channel-binding bytes from the server leaf cert.
426    ///
427    /// PostgreSQL expects the hash of the peer certificate DER for
428    /// `SCRAM-SHA-256-PLUS` channel binding. We currently use SHA-256 here.
429    fn tls_server_end_point_channel_binding(&self) -> Option<Vec<u8>> {
430        let PgStream::Tls(tls) = &self.stream else {
431            return None;
432        };
433
434        let (_, conn) = tls.get_ref();
435        let certs = conn.peer_certificates()?;
436        let leaf_cert = certs.first()?;
437
438        let mut hasher = Sha256::new();
439        hasher.update(leaf_cert.as_ref());
440        Some(hasher.finalize().to_vec())
441    }
442
443    /// Gracefully close the connection by sending a Terminate message.
444    /// This tells the server we're done and allows proper cleanup.
445    pub async fn close(mut self) -> PgResult<()> {
446        use crate::protocol::PgEncoder;
447
448        // Send Terminate packet ('X')
449        let terminate = PgEncoder::encode_terminate();
450        self.write_all_with_timeout(&terminate, "stream write")
451            .await?;
452        self.flush_with_timeout("stream flush").await?;
453
454        Ok(())
455    }
456
457    /// Maximum prepared statements per connection before LRU eviction kicks in.
458    ///
459    /// This prevents memory spikes from dynamic batch filters generating
460    /// thousands of unique SQL shapes within a single request. Using LRU
461    /// eviction instead of nuclear `.clear()` preserves hot statements.
462    pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
463
464    /// Evict the least-recently-used prepared statement if at capacity.
465    ///
466    /// Called before every new statement registration to enforce
467    /// `MAX_PREPARED_PER_CONN`. Both `stmt_cache` (LRU ordering) and
468    /// `prepared_statements` (name→SQL map) are kept in sync.
469    pub(crate) fn evict_prepared_if_full(&mut self) {
470        if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
471            // Pop the LRU entry from the cache
472            if let Some((evicted_hash, evicted_name)) = self.stmt_cache.pop_lru() {
473                self.prepared_statements.remove(&evicted_name);
474                self.column_info_cache.remove(&evicted_hash);
475                self.pending_statement_closes.push(evicted_name);
476            } else {
477                // stmt_cache is empty but prepared_statements is full —
478                // shouldn't happen in normal flow, but handle defensively
479                // by clearing the oldest entry from the HashMap.
480                if let Some(key) = self.prepared_statements.keys().next().cloned() {
481                    self.prepared_statements.remove(&key);
482                    self.pending_statement_closes.push(key);
483                }
484            }
485        }
486    }
487
488    /// Clear all local prepared-statement state for this connection.
489    ///
490    /// Used by one-shot self-heal paths when server-side statement state
491    /// becomes invalid after DDL or failover.
492    pub(crate) fn clear_prepared_statement_state(&mut self) {
493        self.stmt_cache.clear();
494        self.prepared_statements.clear();
495        self.column_info_cache.clear();
496        self.pending_statement_closes.clear();
497    }
498}