Skip to main content

qail_pg/driver/connection/
startup.rs

1//! Startup handshake — authentication, parameter negotiation, prepared stmt mgmt.
2
3use super::helpers::{generate_gss_token, md5_password_message, select_scram_mechanism};
4use super::types::{GSS_SESSION_COUNTER, PgConnection, StartupAuthFlow};
5use crate::driver::stream::PgStream;
6use crate::driver::{
7    AuthSettings, EnterpriseAuthMechanism, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
8};
9use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
10use sha2::{Digest, Sha256};
11use std::sync::atomic::Ordering;
12
13impl PgConnection {
    /// Handle startup sequence (auth + params).
    ///
    /// Reads backend messages in a loop, answers whichever authentication
    /// challenge the server issues, and returns once `ReadyForQuery` arrives
    /// after `AuthenticationOk`.
    ///
    /// At most one authentication flow may run per handshake: the first
    /// challenge is recorded in `startup_auth_flow`, and any later challenge
    /// that does not belong to that flow is rejected as a protocol error.
    ///
    /// # Parameters
    ///
    /// * `user` — used to build the MD5 password message and the SCRAM client.
    /// * `password` — required by the cleartext, MD5 and SCRAM arms; the
    ///   GSS-family arms do not read it.
    /// * `auth_settings` — per-mechanism `allow_*` flags plus the
    ///   `channel_binding` policy handed to `select_scram_mechanism`.
    /// * `gss_token_provider` / `gss_token_provider_ex` — at least one must be
    ///   `Some` for Kerberos V5, GSSAPI or SSPI challenges to be answered.
    ///
    /// # Errors
    ///
    /// * `PgError::Protocol` — a message arrived out of order (for example a
    ///   challenge after `AuthenticationOk`, `ParameterStatus` before it, or a
    ///   message type that is not expected during startup).
    /// * `PgError::Auth` — the requested mechanism is disabled or unsupported,
    ///   a password or token provider is missing, or token generation / SCRAM
    ///   processing failed.
    /// * `PgError::Connection` — the server sent an `ErrorResponse`.
    /// * Any error returned by `self.recv()` or `self.send()` is propagated.
    pub(super) async fn handle_startup(
        &mut self,
        user: &str,
        password: Option<&str>,
        auth_settings: AuthSettings,
        gss_token_provider: Option<GssTokenProvider>,
        gss_token_provider_ex: Option<GssTokenProviderEx>,
    ) -> PgResult<()> {
        // SCRAM client state; populated by the AuthenticationSASL arm.
        let mut scram_client: Option<ScramClient> = None;
        // Which authentication flow (if any) the server has started.
        let mut startup_auth_flow: Option<StartupAuthFlow> = None;
        let mut saw_auth_ok = false;
        // One id per handshake, passed to every `generate_gss_token` call below.
        let gss_session_id = GSS_SESSION_COUNTER.fetch_add(1, Ordering::Relaxed);
        // Counts AuthenticationGSSContinue messages so the exchange is bounded.
        let mut gss_roundtrips: u32 = 0;
        const MAX_GSS_ROUNDTRIPS: u32 = 32;

        loop {
            let msg = self.recv().await?;
            // Once AuthenticationOk has been seen, every Authentication*
            // message (including a second AuthenticationOk) is rejected.
            if saw_auth_ok
                && matches!(
                    &msg,
                    BackendMessage::AuthenticationOk
                        | BackendMessage::AuthenticationKerberosV5
                        | BackendMessage::AuthenticationGSS
                        | BackendMessage::AuthenticationSCMCredential
                        | BackendMessage::AuthenticationGSSContinue(_)
                        | BackendMessage::AuthenticationSSPI
                        | BackendMessage::AuthenticationCleartextPassword
                        | BackendMessage::AuthenticationMD5Password(_)
                        | BackendMessage::AuthenticationSASL(_)
                        | BackendMessage::AuthenticationSASLContinue(_)
                        | BackendMessage::AuthenticationSASLFinal(_)
                )
            {
                return Err(PgError::Protocol(
                    "Received authentication challenge after AuthenticationOk".to_string(),
                ));
            }
            match msg {
                BackendMessage::AuthenticationOk => {
                    // A SCRAM exchange must reach AuthenticationSASLFinal
                    // (server signature check) before the server may
                    // declare success.
                    if let Some(StartupAuthFlow::Scram {
                        server_final_seen: false,
                    }) = startup_auth_flow
                    {
                        return Err(PgError::Protocol(
                            "Received AuthenticationOk before AuthenticationSASLFinal".to_string(),
                        ));
                    }
                    saw_auth_ok = true;
                }
                BackendMessage::AuthenticationKerberosV5 => {
                    // An initial challenge is only valid when no flow has started yet.
                    if let Some(flow) = startup_auth_flow {
                        return Err(PgError::Protocol(format!(
                            "Received AuthenticationKerberosV5 while {} authentication is in progress",
                            flow.label()
                        )));
                    }
                    startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
                        mechanism: EnterpriseAuthMechanism::KerberosV5,
                    });

                    if !auth_settings.allow_kerberos_v5 {
                        return Err(PgError::Auth(
                            "Server requested Kerberos V5 authentication, but Kerberos V5 is disabled by AuthSettings".to_string(),
                        ));
                    }

                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
                        return Err(PgError::Auth(
                            "Kerberos V5 authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
                        ));
                    }

                    // Initial token: no server token exists yet, hence `None`.
                    let token = generate_gss_token(
                        gss_session_id,
                        EnterpriseAuthMechanism::KerberosV5,
                        None,
                        gss_token_provider,
                        gss_token_provider_ex.as_ref(),
                    )
                    .map_err(|e| {
                        PgError::Auth(format!("Kerberos V5 token generation failed: {}", e))
                    })?;

                    self.send(FrontendMessage::GSSResponse(token)).await?;
                }
                BackendMessage::AuthenticationGSS => {
                    // Same shape as the Kerberos V5 arm, with the GSSAPI
                    // mechanism and the `allow_gssapi` flag.
                    if let Some(flow) = startup_auth_flow {
                        return Err(PgError::Protocol(format!(
                            "Received AuthenticationGSS while {} authentication is in progress",
                            flow.label()
                        )));
                    }
                    startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
                        mechanism: EnterpriseAuthMechanism::GssApi,
                    });

                    if !auth_settings.allow_gssapi {
                        return Err(PgError::Auth(
                            "Server requested GSSAPI authentication, but GSSAPI is disabled by AuthSettings".to_string(),
                        ));
                    }

                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
                        return Err(PgError::Auth(
                            "GSSAPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
                        ));
                    }

                    let token = generate_gss_token(
                        gss_session_id,
                        EnterpriseAuthMechanism::GssApi,
                        None,
                        gss_token_provider,
                        gss_token_provider_ex.as_ref(),
                    )
                    .map_err(|e| {
                        PgError::Auth(format!("GSSAPI initial token generation failed: {}", e))
                    })?;

                    self.send(FrontendMessage::GSSResponse(token)).await?;
                }
                BackendMessage::AuthenticationSCMCredential => {
                    // Never supported: report a protocol error if another flow
                    // is already running, otherwise an auth error.
                    if let Some(flow) = startup_auth_flow {
                        return Err(PgError::Protocol(format!(
                            "Received AuthenticationSCMCredential while {} authentication is in progress",
                            flow.label()
                        )));
                    }
                    return Err(PgError::Auth(
                        "Server requested SCM credential authentication (auth code 6). This driver currently does not support Unix-socket credential passing; use SCRAM, GSS/SSPI, or password auth for this connection."
                            .to_string(),
                    ));
                }
                BackendMessage::AuthenticationSSPI => {
                    // Same shape as the Kerberos V5 arm, with the SSPI
                    // mechanism and the `allow_sspi` flag.
                    if let Some(flow) = startup_auth_flow {
                        return Err(PgError::Protocol(format!(
                            "Received AuthenticationSSPI while {} authentication is in progress",
                            flow.label()
                        )));
                    }
                    startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
                        mechanism: EnterpriseAuthMechanism::Sspi,
                    });

                    if !auth_settings.allow_sspi {
                        return Err(PgError::Auth(
                            "Server requested SSPI authentication, but SSPI is disabled by AuthSettings".to_string(),
                        ));
                    }

                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
                        return Err(PgError::Auth(
                            "SSPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
                        ));
                    }

                    let token = generate_gss_token(
                        gss_session_id,
                        EnterpriseAuthMechanism::Sspi,
                        None,
                        gss_token_provider,
                        gss_token_provider_ex.as_ref(),
                    )
                    .map_err(|e| {
                        PgError::Auth(format!("SSPI initial token generation failed: {}", e))
                    })?;

                    self.send(FrontendMessage::GSSResponse(token)).await?;
                }
                BackendMessage::AuthenticationGSSContinue(server_token) => {
                    // Bound the exchange so the server cannot keep the client
                    // in the handshake indefinitely.
                    gss_roundtrips += 1;
                    if gss_roundtrips > MAX_GSS_ROUNDTRIPS {
                        return Err(PgError::Auth(format!(
                            "GSS handshake exceeded {} roundtrips — aborting",
                            MAX_GSS_ROUNDTRIPS
                        )));
                    }

                    // A continue message is only valid inside a GSS-family
                    // flow; reuse the mechanism recorded by the initial arm.
                    let mechanism = match startup_auth_flow {
                        Some(StartupAuthFlow::EnterpriseGss { mechanism }) => mechanism,
                        Some(flow) => {
                            return Err(PgError::Protocol(format!(
                                "Received AuthenticationGSSContinue while {} authentication is in progress",
                                flow.label()
                            )));
                        }
                        None => {
                            return Err(PgError::Auth(
                                "Received GSSContinue without AuthenticationGSS/SSPI/KerberosV5 init"
                                    .to_string(),
                            ));
                        }
                    };

                    if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
                        return Err(PgError::Auth(
                            "Received GSSContinue but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
                        ));
                    }

                    // Feed the server's token back into the provider to get
                    // the next client token.
                    let token = generate_gss_token(
                        gss_session_id,
                        mechanism,
                        Some(&server_token),
                        gss_token_provider,
                        gss_token_provider_ex.as_ref(),
                    )
                    .map_err(|e| {
                        PgError::Auth(format!("GSS continue token generation failed: {}", e))
                    })?;

                    // Only send the response if there is actually a token to
                    // send.  When gss_init_sec_context returns GSS_S_COMPLETE
                    // on the final round, the token may be empty.  Sending an
                    // empty GSSResponse ('p') after the server already
                    // considers auth complete trips the "invalid frontend
                    // message type 112" FATAL in PostgreSQL.
                    if !token.is_empty() {
                        self.send(FrontendMessage::GSSResponse(token)).await?;
                    }
                }
                BackendMessage::AuthenticationCleartextPassword => {
                    if let Some(flow) = startup_auth_flow {
                        return Err(PgError::Protocol(format!(
                            "Received AuthenticationCleartextPassword while {} authentication is in progress",
                            flow.label()
                        )));
                    }
                    startup_auth_flow = Some(StartupAuthFlow::CleartextPassword);

                    // The policy check runs before the password is touched, so
                    // a disabled mechanism never sends the password.
                    if !auth_settings.allow_cleartext_password {
                        return Err(PgError::Auth(
                            "Server requested cleartext authentication, but cleartext is disabled by AuthSettings"
                                .to_string(),
                        ));
                    }
                    let password = password.ok_or_else(|| {
                        PgError::Auth("Password required for cleartext authentication".to_string())
                    })?;
                    self.send(FrontendMessage::PasswordMessage(password.to_string()))
                        .await?;
                }
                BackendMessage::AuthenticationMD5Password(salt) => {
                    if let Some(flow) = startup_auth_flow {
                        return Err(PgError::Protocol(format!(
                            "Received AuthenticationMD5Password while {} authentication is in progress",
                            flow.label()
                        )));
                    }
                    startup_auth_flow = Some(StartupAuthFlow::Md5Password);

                    if !auth_settings.allow_md5_password {
                        return Err(PgError::Auth(
                            "Server requested MD5 authentication, but MD5 is disabled by AuthSettings"
                                .to_string(),
                        ));
                    }
                    let password = password.ok_or_else(|| {
                        PgError::Auth("Password required for MD5 authentication".to_string())
                    })?;
                    // The helper combines user, password and the server salt
                    // into the message that is sent instead of the password.
                    let md5_password = md5_password_message(user, password, salt);
                    self.send(FrontendMessage::PasswordMessage(md5_password))
                        .await?;
                }
                BackendMessage::AuthenticationSASL(mechanisms) => {
                    if let Some(flow) = startup_auth_flow {
                        return Err(PgError::Protocol(format!(
                            "Received AuthenticationSASL while {} authentication is in progress",
                            flow.label()
                        )));
                    }
                    startup_auth_flow = Some(StartupAuthFlow::Scram {
                        server_final_seen: false,
                    });

                    if !auth_settings.allow_scram_sha_256 {
                        return Err(PgError::Auth(
                            "Server requested SCRAM authentication, but SCRAM is disabled by AuthSettings"
                                .to_string(),
                        ));
                    }
                    let password = password.ok_or_else(|| {
                        PgError::Auth("Password required for SCRAM authentication".to_string())
                    })?;

                    // Choose a mechanism from the server's list, the TLS
                    // binding bytes (if any) and the channel-binding policy.
                    let tls_binding = self.tls_server_end_point_channel_binding();
                    let (mechanism, channel_binding_data) = select_scram_mechanism(
                        &mechanisms,
                        tls_binding,
                        auth_settings.channel_binding,
                    )
                    .map_err(PgError::Auth)?;

                    // Build a channel-bound client only when the selection
                    // returned binding data.
                    let client = if let Some(binding_data) = channel_binding_data {
                        ScramClient::new_with_tls_server_end_point(user, password, binding_data)
                    } else {
                        ScramClient::new(user, password)
                    };
                    let first_message = client.client_first_message();

                    self.send(FrontendMessage::SASLInitialResponse {
                        mechanism,
                        data: first_message,
                    })
                    .await?;

                    // Keep the client for the SASLContinue / SASLFinal arms.
                    scram_client = Some(client);
                }
                BackendMessage::AuthenticationSASLContinue(server_data) => {
                    // Valid only inside a SCRAM flow that has not yet seen
                    // the server-final message.
                    match startup_auth_flow {
                        Some(StartupAuthFlow::Scram {
                            server_final_seen: false,
                        }) => {}
                        Some(StartupAuthFlow::Scram {
                            server_final_seen: true,
                        }) => {
                            return Err(PgError::Protocol(
                                "Received AuthenticationSASLContinue after AuthenticationSASLFinal"
                                    .to_string(),
                            ));
                        }
                        Some(flow) => {
                            return Err(PgError::Protocol(format!(
                                "Received AuthenticationSASLContinue while {} authentication is in progress",
                                flow.label()
                            )));
                        }
                        None => {
                            return Err(PgError::Auth(
                                "Received SASL Continue without SASL init".to_string(),
                            ));
                        }
                    }

                    let client = scram_client.as_mut().ok_or_else(|| {
                        PgError::Auth("Received SASL Continue without SASL init".to_string())
                    })?;

                    // Process the server-first data and reply with the
                    // client-final message.
                    let final_message = client
                        .process_server_first(&server_data)
                        .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;

                    self.send(FrontendMessage::SASLResponse(final_message))
                        .await?;
                }
                BackendMessage::AuthenticationSASLFinal(server_signature) => {
                    // Flip `server_final_seen` exactly once; duplicates and
                    // out-of-flow finals are rejected.
                    match startup_auth_flow {
                        Some(StartupAuthFlow::Scram {
                            server_final_seen: false,
                        }) => {
                            startup_auth_flow = Some(StartupAuthFlow::Scram {
                                server_final_seen: true,
                            });
                        }
                        Some(StartupAuthFlow::Scram {
                            server_final_seen: true,
                        }) => {
                            return Err(PgError::Protocol(
                                "Received duplicate AuthenticationSASLFinal".to_string(),
                            ));
                        }
                        Some(flow) => {
                            return Err(PgError::Protocol(format!(
                                "Received AuthenticationSASLFinal while {} authentication is in progress",
                                flow.label()
                            )));
                        }
                        None => {
                            return Err(PgError::Auth(
                                "Received SASL Final without SASL init".to_string(),
                            ));
                        }
                    }

                    let client = scram_client.as_ref().ok_or_else(|| {
                        PgError::Auth("Received SASL Final without SASL init".to_string())
                    })?;
                    // A failed check of the server-final data aborts the
                    // handshake.
                    client
                        .verify_server_final(&server_signature)
                        .map_err(|e| PgError::Auth(format!("Server verification failed: {}", e)))?;
                }
                BackendMessage::ParameterStatus { .. } => {
                    // This arm only validates ordering; the payload is not
                    // inspected here.
                    if !saw_auth_ok {
                        return Err(PgError::Protocol(
                            "Received ParameterStatus before AuthenticationOk".to_string(),
                        ));
                    }
                }
                BackendMessage::NegotiateProtocolVersion {
                    newest_minor_supported,
                    unrecognized_protocol_options,
                } => {
                    // Only accepted before AuthenticationOk.
                    if saw_auth_ok {
                        return Err(PgError::Protocol(
                            "Received NegotiateProtocolVersion after AuthenticationOk".to_string(),
                        ));
                    }
                    // Accept either a bare minor number that fits in u16, or a
                    // packed `(major << 16) | minor` value whose major is 3.
                    let negotiated = if let Ok(minor) = u16::try_from(newest_minor_supported) {
                        minor
                    } else {
                        let packed = u32::try_from(newest_minor_supported).map_err(|_| {
                            PgError::Protocol(format!(
                                "Invalid NegotiateProtocolVersion newest_minor_supported: {}",
                                newest_minor_supported
                            ))
                        })?;
                        let major = (packed >> 16) as u16;
                        let minor = (packed & 0xFFFF) as u16;
                        if major != 3 {
                            return Err(PgError::Protocol(format!(
                                "Invalid NegotiateProtocolVersion newest_minor_supported: {}",
                                newest_minor_supported
                            )));
                        }
                        minor
                    };
                    // The server may negotiate down, never above what was requested.
                    if negotiated > self.requested_protocol_minor {
                        return Err(PgError::Protocol(format!(
                            "Server negotiated protocol minor {} above requested {}",
                            negotiated, self.requested_protocol_minor
                        )));
                    }
                    self.negotiated_protocol_minor = negotiated;
                    if !unrecognized_protocol_options.is_empty() {
                        tracing::debug!(
                            negotiated_minor = negotiated,
                            unrecognized_count = unrecognized_protocol_options.len(),
                            "startup_negotiate_protocol_version"
                        );
                    }
                }
                BackendMessage::BackendKeyData {
                    process_id,
                    secret_key,
                } => {
                    if !saw_auth_ok {
                        return Err(PgError::Protocol(
                            "Received BackendKeyData before AuthenticationOk".to_string(),
                        ));
                    }
                    // Store the backend process id and secret key on the connection.
                    self.process_id = process_id;
                    self.cancel_key_bytes = secret_key;
                }
                BackendMessage::ReadyForQuery(TransactionStatus::Idle)
                | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
                | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
                    // Startup ends here, but only if authentication succeeded.
                    if !saw_auth_ok {
                        return Err(PgError::Protocol(
                            "Startup completed without AuthenticationOk".to_string(),
                        ));
                    }
                    return Ok(());
                }
                BackendMessage::ErrorResponse(err) => {
                    // Only the server's message text is carried into the error.
                    return Err(PgError::Connection(err.message));
                }
                // Notices are ignored during startup.
                BackendMessage::NoticeResponse(_) => {}
                _ => {
                    return Err(PgError::Protocol(
                        "Unexpected backend message during startup".to_string(),
                    ));
                }
            }
        }
    }
480
481    /// Build SCRAM `tls-server-end-point` channel-binding bytes from the server leaf cert.
482    ///
483    /// PostgreSQL expects the hash of the peer certificate DER for
484    /// `SCRAM-SHA-256-PLUS` channel binding. We currently use SHA-256 here.
485    fn tls_server_end_point_channel_binding(&self) -> Option<Vec<u8>> {
486        let PgStream::Tls(tls) = &self.stream else {
487            return None;
488        };
489
490        let (_, conn) = tls.get_ref();
491        let certs = conn.peer_certificates()?;
492        let leaf_cert = certs.first()?;
493
494        let mut hasher = Sha256::new();
495        hasher.update(leaf_cert.as_ref());
496        Some(hasher.finalize().to_vec())
497    }
498
499    /// Gracefully close the connection by sending a Terminate message.
500    /// This tells the server we're done and allows proper cleanup.
501    pub async fn close(mut self) -> PgResult<()> {
502        use crate::protocol::PgEncoder;
503
504        // Send Terminate packet ('X')
505        let terminate = PgEncoder::encode_terminate();
506        self.write_all_with_timeout(&terminate, "stream write")
507            .await?;
508        self.flush_with_timeout("stream flush").await?;
509
510        Ok(())
511    }
512
    /// Maximum prepared statements per connection before LRU eviction kicks in.
    ///
    /// This prevents memory spikes from dynamic batch filters generating
    /// thousands of unique SQL shapes within a single request. Evicting one
    /// least-recently-used entry at a time, instead of clearing every
    /// statement at once, preserves hot statements.
    ///
    /// The limit is compared against `prepared_statements.len()` in
    /// `evict_prepared_if_full`.
    pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
519
520    /// Evict the least-recently-used prepared statement if at capacity.
521    ///
522    /// Called before every new statement registration to enforce
523    /// `MAX_PREPARED_PER_CONN`. Both `stmt_cache` (LRU ordering) and
524    /// `prepared_statements` (name→SQL map) are kept in sync.
525    pub(crate) fn evict_prepared_if_full(&mut self) {
526        if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
527            // Pop the LRU entry from the cache
528            if let Some((evicted_hash, evicted_name)) = self.stmt_cache.pop_lru() {
529                self.prepared_statements.remove(&evicted_name);
530                self.column_info_cache.remove(&evicted_hash);
531                self.pending_statement_closes.push(evicted_name);
532            } else {
533                // stmt_cache is empty but prepared_statements is full —
534                // shouldn't happen in normal flow, but handle defensively
535                // by clearing the oldest entry from the HashMap.
536                if let Some(key) = self.prepared_statements.keys().next().cloned() {
537                    self.prepared_statements.remove(&key);
538                    self.pending_statement_closes.push(key);
539                }
540            }
541        }
542    }
543
544    /// Clear all local prepared-statement state for this connection.
545    ///
546    /// Used by one-shot self-heal paths when server-side statement state
547    /// becomes invalid after DDL or failover.
548    pub(crate) fn clear_prepared_statement_state(&mut self) {
549        self.stmt_cache.clear();
550        self.prepared_statements.clear();
551        self.column_info_cache.clear();
552        self.pending_statement_closes.clear();
553    }
554}