Skip to main content

qail_pg/driver/connection/
startup.rs

1//! Startup handshake — authentication, parameter negotiation, prepared stmt mgmt.
2
3use super::helpers::{generate_gss_token, md5_password_message, select_scram_mechanism};
4use super::types::{GSS_SESSION_COUNTER, PgConnection, StartupAuthFlow};
5use crate::driver::stream::PgStream;
6use crate::driver::{
7    AuthSettings, EnterpriseAuthMechanism, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
8};
9use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
10use sha2::{Digest, Sha256};
11use std::sync::atomic::Ordering;
12
13impl PgConnection {
14    /// Handle startup sequence (auth + params).
15    pub(super) async fn handle_startup(
16        &mut self,
17        user: &str,
18        password: Option<&str>,
19        auth_settings: AuthSettings,
20        gss_token_provider: Option<GssTokenProvider>,
21        gss_token_provider_ex: Option<GssTokenProviderEx>,
22    ) -> PgResult<()> {
23        let mut scram_client: Option<ScramClient> = None;
24        let mut startup_auth_flow: Option<StartupAuthFlow> = None;
25        let mut saw_auth_ok = false;
26        let gss_session_id = GSS_SESSION_COUNTER.fetch_add(1, Ordering::Relaxed);
27        let mut gss_roundtrips: u32 = 0;
28        const MAX_GSS_ROUNDTRIPS: u32 = 32;
29
30        loop {
31            let msg = self.recv().await?;
32            if saw_auth_ok
33                && matches!(
34                    &msg,
35                    BackendMessage::AuthenticationOk
36                        | BackendMessage::AuthenticationKerberosV5
37                        | BackendMessage::AuthenticationGSS
38                        | BackendMessage::AuthenticationSCMCredential
39                        | BackendMessage::AuthenticationGSSContinue(_)
40                        | BackendMessage::AuthenticationSSPI
41                        | BackendMessage::AuthenticationCleartextPassword
42                        | BackendMessage::AuthenticationMD5Password(_)
43                        | BackendMessage::AuthenticationSASL(_)
44                        | BackendMessage::AuthenticationSASLContinue(_)
45                        | BackendMessage::AuthenticationSASLFinal(_)
46                )
47            {
48                return Err(PgError::Protocol(
49                    "Received authentication challenge after AuthenticationOk".to_string(),
50                ));
51            }
52            match msg {
53                BackendMessage::AuthenticationOk => {
54                    if let Some(StartupAuthFlow::Scram {
55                        server_final_seen: false,
56                    }) = startup_auth_flow
57                    {
58                        return Err(PgError::Protocol(
59                            "Received AuthenticationOk before AuthenticationSASLFinal".to_string(),
60                        ));
61                    }
62                    saw_auth_ok = true;
63                }
64                BackendMessage::AuthenticationKerberosV5 => {
65                    if let Some(flow) = startup_auth_flow {
66                        return Err(PgError::Protocol(format!(
67                            "Received AuthenticationKerberosV5 while {} authentication is in progress",
68                            flow.label()
69                        )));
70                    }
71                    startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
72                        mechanism: EnterpriseAuthMechanism::KerberosV5,
73                    });
74
75                    if !auth_settings.allow_kerberos_v5 {
76                        return Err(PgError::Auth(
77                            "Server requested Kerberos V5 authentication, but Kerberos V5 is disabled by AuthSettings".to_string(),
78                        ));
79                    }
80
81                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
82                        return Err(PgError::Auth(
83                            "Kerberos V5 authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
84                        ));
85                    }
86
87                    let token = generate_gss_token(
88                        gss_session_id,
89                        EnterpriseAuthMechanism::KerberosV5,
90                        None,
91                        gss_token_provider,
92                        gss_token_provider_ex.as_ref(),
93                    )
94                    .map_err(|e| {
95                        PgError::Auth(format!("Kerberos V5 token generation failed: {}", e))
96                    })?;
97
98                    self.send(FrontendMessage::GSSResponse(token)).await?;
99                }
100                BackendMessage::AuthenticationGSS => {
101                    if let Some(flow) = startup_auth_flow {
102                        return Err(PgError::Protocol(format!(
103                            "Received AuthenticationGSS while {} authentication is in progress",
104                            flow.label()
105                        )));
106                    }
107                    startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
108                        mechanism: EnterpriseAuthMechanism::GssApi,
109                    });
110
111                    if !auth_settings.allow_gssapi {
112                        return Err(PgError::Auth(
113                            "Server requested GSSAPI authentication, but GSSAPI is disabled by AuthSettings".to_string(),
114                        ));
115                    }
116
117                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
118                        return Err(PgError::Auth(
119                            "GSSAPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
120                        ));
121                    }
122
123                    let token = generate_gss_token(
124                        gss_session_id,
125                        EnterpriseAuthMechanism::GssApi,
126                        None,
127                        gss_token_provider,
128                        gss_token_provider_ex.as_ref(),
129                    )
130                    .map_err(|e| {
131                        PgError::Auth(format!("GSSAPI initial token generation failed: {}", e))
132                    })?;
133
134                    self.send(FrontendMessage::GSSResponse(token)).await?;
135                }
136                BackendMessage::AuthenticationSCMCredential => {
137                    if let Some(flow) = startup_auth_flow {
138                        return Err(PgError::Protocol(format!(
139                            "Received AuthenticationSCMCredential while {} authentication is in progress",
140                            flow.label()
141                        )));
142                    }
143                    return Err(PgError::Auth(
144                        "Server requested SCM credential authentication (auth code 6). This driver currently does not support Unix-socket credential passing; use SCRAM, GSS/SSPI, or password auth for this connection."
145                            .to_string(),
146                    ));
147                }
148                BackendMessage::AuthenticationSSPI => {
149                    if let Some(flow) = startup_auth_flow {
150                        return Err(PgError::Protocol(format!(
151                            "Received AuthenticationSSPI while {} authentication is in progress",
152                            flow.label()
153                        )));
154                    }
155                    startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
156                        mechanism: EnterpriseAuthMechanism::Sspi,
157                    });
158
159                    if !auth_settings.allow_sspi {
160                        return Err(PgError::Auth(
161                            "Server requested SSPI authentication, but SSPI is disabled by AuthSettings".to_string(),
162                        ));
163                    }
164
165                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
166                        return Err(PgError::Auth(
167                            "SSPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
168                        ));
169                    }
170
171                    let token = generate_gss_token(
172                        gss_session_id,
173                        EnterpriseAuthMechanism::Sspi,
174                        None,
175                        gss_token_provider,
176                        gss_token_provider_ex.as_ref(),
177                    )
178                    .map_err(|e| {
179                        PgError::Auth(format!("SSPI initial token generation failed: {}", e))
180                    })?;
181
182                    self.send(FrontendMessage::GSSResponse(token)).await?;
183                }
184                BackendMessage::AuthenticationGSSContinue(server_token) => {
185                    gss_roundtrips += 1;
186                    if gss_roundtrips > MAX_GSS_ROUNDTRIPS {
187                        return Err(PgError::Auth(format!(
188                            "GSS handshake exceeded {} roundtrips — aborting",
189                            MAX_GSS_ROUNDTRIPS
190                        )));
191                    }
192
193                    let mechanism = match startup_auth_flow {
194                        Some(StartupAuthFlow::EnterpriseGss { mechanism }) => mechanism,
195                        Some(flow) => {
196                            return Err(PgError::Protocol(format!(
197                                "Received AuthenticationGSSContinue while {} authentication is in progress",
198                                flow.label()
199                            )));
200                        }
201                        None => {
202                            return Err(PgError::Auth(
203                                "Received GSSContinue without AuthenticationGSS/SSPI/KerberosV5 init"
204                                    .to_string(),
205                            ));
206                        }
207                    };
208
209                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
210                        return Err(PgError::Auth(
211                            "Received GSSContinue but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
212                        ));
213                    }
214
215                    let token = generate_gss_token(
216                        gss_session_id,
217                        mechanism,
218                        Some(&server_token),
219                        gss_token_provider,
220                        gss_token_provider_ex.as_ref(),
221                    )
222                    .map_err(|e| {
223                        PgError::Auth(format!("GSS continue token generation failed: {}", e))
224                    })?;
225
226                    // Only send the response if there is actually a token to
227                    // send.  When gss_init_sec_context returns GSS_S_COMPLETE
228                    // on the final round, the token may be empty.  Sending an
229                    // empty GSSResponse ('p') after the server already
230                    // considers auth complete trips the "invalid frontend
231                    // message type 112" FATAL in PostgreSQL.
232                    if !token.is_empty() {
233                        self.send(FrontendMessage::GSSResponse(token)).await?;
234                    }
235                }
236                BackendMessage::AuthenticationCleartextPassword => {
237                    if let Some(flow) = startup_auth_flow {
238                        return Err(PgError::Protocol(format!(
239                            "Received AuthenticationCleartextPassword while {} authentication is in progress",
240                            flow.label()
241                        )));
242                    }
243                    startup_auth_flow = Some(StartupAuthFlow::CleartextPassword);
244
245                    if !auth_settings.allow_cleartext_password {
246                        return Err(PgError::Auth(
247                            "Server requested cleartext authentication, but cleartext is disabled by AuthSettings"
248                                .to_string(),
249                        ));
250                    }
251                    let password = password.ok_or_else(|| {
252                        PgError::Auth("Password required for cleartext authentication".to_string())
253                    })?;
254                    self.send(FrontendMessage::PasswordMessage(password.to_string()))
255                        .await?;
256                }
257                BackendMessage::AuthenticationMD5Password(salt) => {
258                    if let Some(flow) = startup_auth_flow {
259                        return Err(PgError::Protocol(format!(
260                            "Received AuthenticationMD5Password while {} authentication is in progress",
261                            flow.label()
262                        )));
263                    }
264                    startup_auth_flow = Some(StartupAuthFlow::Md5Password);
265
266                    if !auth_settings.allow_md5_password {
267                        return Err(PgError::Auth(
268                            "Server requested MD5 authentication, but MD5 is disabled by AuthSettings"
269                                .to_string(),
270                        ));
271                    }
272                    let password = password.ok_or_else(|| {
273                        PgError::Auth("Password required for MD5 authentication".to_string())
274                    })?;
275                    let md5_password = md5_password_message(user, password, salt);
276                    self.send(FrontendMessage::PasswordMessage(md5_password))
277                        .await?;
278                }
279                BackendMessage::AuthenticationSASL(mechanisms) => {
280                    if let Some(flow) = startup_auth_flow {
281                        return Err(PgError::Protocol(format!(
282                            "Received AuthenticationSASL while {} authentication is in progress",
283                            flow.label()
284                        )));
285                    }
286                    startup_auth_flow = Some(StartupAuthFlow::Scram {
287                        server_final_seen: false,
288                    });
289
290                    if !auth_settings.allow_scram_sha_256 {
291                        return Err(PgError::Auth(
292                            "Server requested SCRAM authentication, but SCRAM is disabled by AuthSettings"
293                                .to_string(),
294                        ));
295                    }
296                    let password = password.ok_or_else(|| {
297                        PgError::Auth("Password required for SCRAM authentication".to_string())
298                    })?;
299
300                    let tls_binding = self.tls_server_end_point_channel_binding();
301                    let (mechanism, channel_binding_data) = select_scram_mechanism(
302                        &mechanisms,
303                        tls_binding,
304                        auth_settings.channel_binding,
305                    )
306                    .map_err(PgError::Auth)?;
307
308                    let client = if let Some(binding_data) = channel_binding_data {
309                        ScramClient::new_with_tls_server_end_point(user, password, binding_data)
310                    } else {
311                        ScramClient::new(user, password)
312                    };
313                    let first_message = client.client_first_message();
314
315                    self.send(FrontendMessage::SASLInitialResponse {
316                        mechanism,
317                        data: first_message,
318                    })
319                    .await?;
320
321                    scram_client = Some(client);
322                }
323                BackendMessage::AuthenticationSASLContinue(server_data) => {
324                    match startup_auth_flow {
325                        Some(StartupAuthFlow::Scram {
326                            server_final_seen: false,
327                        }) => {}
328                        Some(StartupAuthFlow::Scram {
329                            server_final_seen: true,
330                        }) => {
331                            return Err(PgError::Protocol(
332                                "Received AuthenticationSASLContinue after AuthenticationSASLFinal"
333                                    .to_string(),
334                            ));
335                        }
336                        Some(flow) => {
337                            return Err(PgError::Protocol(format!(
338                                "Received AuthenticationSASLContinue while {} authentication is in progress",
339                                flow.label()
340                            )));
341                        }
342                        None => {
343                            return Err(PgError::Auth(
344                                "Received SASL Continue without SASL init".to_string(),
345                            ));
346                        }
347                    }
348
349                    let client = scram_client.as_mut().ok_or_else(|| {
350                        PgError::Auth("Received SASL Continue without SASL init".to_string())
351                    })?;
352
353                    let final_message = client
354                        .process_server_first(&server_data)
355                        .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
356
357                    self.send(FrontendMessage::SASLResponse(final_message))
358                        .await?;
359                }
360                BackendMessage::AuthenticationSASLFinal(server_signature) => {
361                    match startup_auth_flow {
362                        Some(StartupAuthFlow::Scram {
363                            server_final_seen: false,
364                        }) => {
365                            startup_auth_flow = Some(StartupAuthFlow::Scram {
366                                server_final_seen: true,
367                            });
368                        }
369                        Some(StartupAuthFlow::Scram {
370                            server_final_seen: true,
371                        }) => {
372                            return Err(PgError::Protocol(
373                                "Received duplicate AuthenticationSASLFinal".to_string(),
374                            ));
375                        }
376                        Some(flow) => {
377                            return Err(PgError::Protocol(format!(
378                                "Received AuthenticationSASLFinal while {} authentication is in progress",
379                                flow.label()
380                            )));
381                        }
382                        None => {
383                            return Err(PgError::Auth(
384                                "Received SASL Final without SASL init".to_string(),
385                            ));
386                        }
387                    }
388
389                    let client = scram_client.as_ref().ok_or_else(|| {
390                        PgError::Auth("Received SASL Final without SASL init".to_string())
391                    })?;
392                    client
393                        .verify_server_final(&server_signature)
394                        .map_err(|e| PgError::Auth(format!("Server verification failed: {}", e)))?;
395                }
396                BackendMessage::ParameterStatus { .. } => {
397                    if !saw_auth_ok {
398                        return Err(PgError::Protocol(
399                            "Received ParameterStatus before AuthenticationOk".to_string(),
400                        ));
401                    }
402                }
403                BackendMessage::BackendKeyData {
404                    process_id,
405                    secret_key,
406                } => {
407                    if !saw_auth_ok {
408                        return Err(PgError::Protocol(
409                            "Received BackendKeyData before AuthenticationOk".to_string(),
410                        ));
411                    }
412                    self.process_id = process_id;
413                    self.secret_key = secret_key;
414                }
415                BackendMessage::ReadyForQuery(TransactionStatus::Idle)
416                | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
417                | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
418                    if !saw_auth_ok {
419                        return Err(PgError::Protocol(
420                            "Startup completed without AuthenticationOk".to_string(),
421                        ));
422                    }
423                    return Ok(());
424                }
425                BackendMessage::ErrorResponse(err) => {
426                    return Err(PgError::Connection(err.message));
427                }
428                BackendMessage::NoticeResponse(_) => {}
429                _ => {
430                    return Err(PgError::Protocol(
431                        "Unexpected backend message during startup".to_string(),
432                    ));
433                }
434            }
435        }
436    }
437
438    /// Build SCRAM `tls-server-end-point` channel-binding bytes from the server leaf cert.
439    ///
440    /// PostgreSQL expects the hash of the peer certificate DER for
441    /// `SCRAM-SHA-256-PLUS` channel binding. We currently use SHA-256 here.
442    fn tls_server_end_point_channel_binding(&self) -> Option<Vec<u8>> {
443        let PgStream::Tls(tls) = &self.stream else {
444            return None;
445        };
446
447        let (_, conn) = tls.get_ref();
448        let certs = conn.peer_certificates()?;
449        let leaf_cert = certs.first()?;
450
451        let mut hasher = Sha256::new();
452        hasher.update(leaf_cert.as_ref());
453        Some(hasher.finalize().to_vec())
454    }
455
456    /// Gracefully close the connection by sending a Terminate message.
457    /// This tells the server we're done and allows proper cleanup.
458    pub async fn close(mut self) -> PgResult<()> {
459        use crate::protocol::PgEncoder;
460
461        // Send Terminate packet ('X')
462        let terminate = PgEncoder::encode_terminate();
463        self.write_all_with_timeout(&terminate, "stream write")
464            .await?;
465        self.flush_with_timeout("stream flush").await?;
466
467        Ok(())
468    }
469
470    /// Maximum prepared statements per connection before LRU eviction kicks in.
471    ///
472    /// This prevents memory spikes from dynamic batch filters generating
473    /// thousands of unique SQL shapes within a single request. Using LRU
474    /// eviction instead of nuclear `.clear()` preserves hot statements.
475    pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
476
477    /// Evict the least-recently-used prepared statement if at capacity.
478    ///
479    /// Called before every new statement registration to enforce
480    /// `MAX_PREPARED_PER_CONN`. Both `stmt_cache` (LRU ordering) and
481    /// `prepared_statements` (name→SQL map) are kept in sync.
482    pub(crate) fn evict_prepared_if_full(&mut self) {
483        if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
484            // Pop the LRU entry from the cache
485            if let Some((evicted_hash, evicted_name)) = self.stmt_cache.pop_lru() {
486                self.prepared_statements.remove(&evicted_name);
487                self.column_info_cache.remove(&evicted_hash);
488                self.pending_statement_closes.push(evicted_name);
489            } else {
490                // stmt_cache is empty but prepared_statements is full —
491                // shouldn't happen in normal flow, but handle defensively
492                // by clearing the oldest entry from the HashMap.
493                if let Some(key) = self.prepared_statements.keys().next().cloned() {
494                    self.prepared_statements.remove(&key);
495                    self.pending_statement_closes.push(key);
496                }
497            }
498        }
499    }
500
501    /// Clear all local prepared-statement state for this connection.
502    ///
503    /// Used by one-shot self-heal paths when server-side statement state
504    /// becomes invalid after DDL or failover.
505    pub(crate) fn clear_prepared_statement_state(&mut self) {
506        self.stmt_cache.clear();
507        self.prepared_statements.clear();
508        self.column_info_cache.clear();
509        self.pending_statement_closes.clear();
510    }
511}