qail_pg/driver/connection/
startup.rs1use super::helpers::{generate_gss_token, md5_password_message, select_scram_mechanism};
4use super::types::{GSS_SESSION_COUNTER, PgConnection, StartupAuthFlow};
5use crate::driver::stream::PgStream;
6use crate::driver::{
7 AuthSettings, EnterpriseAuthMechanism, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
8};
9use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
10use sha2::{Digest, Sha256};
11use std::sync::atomic::Ordering;
12
13impl PgConnection {
14 pub(super) async fn handle_startup(
16 &mut self,
17 user: &str,
18 password: Option<&str>,
19 auth_settings: AuthSettings,
20 gss_token_provider: Option<GssTokenProvider>,
21 gss_token_provider_ex: Option<GssTokenProviderEx>,
22 ) -> PgResult<()> {
23 let mut scram_client: Option<ScramClient> = None;
24 let mut startup_auth_flow: Option<StartupAuthFlow> = None;
25 let mut saw_auth_ok = false;
26 let gss_session_id = GSS_SESSION_COUNTER.fetch_add(1, Ordering::Relaxed);
27 let mut gss_roundtrips: u32 = 0;
28 const MAX_GSS_ROUNDTRIPS: u32 = 32;
29
30 loop {
31 let msg = self.recv().await?;
32 if saw_auth_ok
33 && matches!(
34 &msg,
35 BackendMessage::AuthenticationOk
36 | BackendMessage::AuthenticationKerberosV5
37 | BackendMessage::AuthenticationGSS
38 | BackendMessage::AuthenticationSCMCredential
39 | BackendMessage::AuthenticationGSSContinue(_)
40 | BackendMessage::AuthenticationSSPI
41 | BackendMessage::AuthenticationCleartextPassword
42 | BackendMessage::AuthenticationMD5Password(_)
43 | BackendMessage::AuthenticationSASL(_)
44 | BackendMessage::AuthenticationSASLContinue(_)
45 | BackendMessage::AuthenticationSASLFinal(_)
46 )
47 {
48 return Err(PgError::Protocol(
49 "Received authentication challenge after AuthenticationOk".to_string(),
50 ));
51 }
52 match msg {
53 BackendMessage::AuthenticationOk => {
54 if let Some(StartupAuthFlow::Scram {
55 server_final_seen: false,
56 }) = startup_auth_flow
57 {
58 return Err(PgError::Protocol(
59 "Received AuthenticationOk before AuthenticationSASLFinal".to_string(),
60 ));
61 }
62 saw_auth_ok = true;
63 }
64 BackendMessage::AuthenticationKerberosV5 => {
65 if let Some(flow) = startup_auth_flow {
66 return Err(PgError::Protocol(format!(
67 "Received AuthenticationKerberosV5 while {} authentication is in progress",
68 flow.label()
69 )));
70 }
71 startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
72 mechanism: EnterpriseAuthMechanism::KerberosV5,
73 });
74
75 if !auth_settings.allow_kerberos_v5 {
76 return Err(PgError::Auth(
77 "Server requested Kerberos V5 authentication, but Kerberos V5 is disabled by AuthSettings".to_string(),
78 ));
79 }
80
81 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
82 return Err(PgError::Auth(
83 "Kerberos V5 authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
84 ));
85 }
86
87 let token = generate_gss_token(
88 gss_session_id,
89 EnterpriseAuthMechanism::KerberosV5,
90 None,
91 gss_token_provider,
92 gss_token_provider_ex.as_ref(),
93 )
94 .map_err(|e| {
95 PgError::Auth(format!("Kerberos V5 token generation failed: {}", e))
96 })?;
97
98 self.send(FrontendMessage::GSSResponse(token)).await?;
99 }
100 BackendMessage::AuthenticationGSS => {
101 if let Some(flow) = startup_auth_flow {
102 return Err(PgError::Protocol(format!(
103 "Received AuthenticationGSS while {} authentication is in progress",
104 flow.label()
105 )));
106 }
107 startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
108 mechanism: EnterpriseAuthMechanism::GssApi,
109 });
110
111 if !auth_settings.allow_gssapi {
112 return Err(PgError::Auth(
113 "Server requested GSSAPI authentication, but GSSAPI is disabled by AuthSettings".to_string(),
114 ));
115 }
116
117 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
118 return Err(PgError::Auth(
119 "GSSAPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
120 ));
121 }
122
123 let token = generate_gss_token(
124 gss_session_id,
125 EnterpriseAuthMechanism::GssApi,
126 None,
127 gss_token_provider,
128 gss_token_provider_ex.as_ref(),
129 )
130 .map_err(|e| {
131 PgError::Auth(format!("GSSAPI initial token generation failed: {}", e))
132 })?;
133
134 self.send(FrontendMessage::GSSResponse(token)).await?;
135 }
136 BackendMessage::AuthenticationSCMCredential => {
137 if let Some(flow) = startup_auth_flow {
138 return Err(PgError::Protocol(format!(
139 "Received AuthenticationSCMCredential while {} authentication is in progress",
140 flow.label()
141 )));
142 }
143 return Err(PgError::Auth(
144 "Server requested SCM credential authentication (auth code 6). This driver currently does not support Unix-socket credential passing; use SCRAM, GSS/SSPI, or password auth for this connection."
145 .to_string(),
146 ));
147 }
148 BackendMessage::AuthenticationSSPI => {
149 if let Some(flow) = startup_auth_flow {
150 return Err(PgError::Protocol(format!(
151 "Received AuthenticationSSPI while {} authentication is in progress",
152 flow.label()
153 )));
154 }
155 startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
156 mechanism: EnterpriseAuthMechanism::Sspi,
157 });
158
159 if !auth_settings.allow_sspi {
160 return Err(PgError::Auth(
161 "Server requested SSPI authentication, but SSPI is disabled by AuthSettings".to_string(),
162 ));
163 }
164
165 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
166 return Err(PgError::Auth(
167 "SSPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
168 ));
169 }
170
171 let token = generate_gss_token(
172 gss_session_id,
173 EnterpriseAuthMechanism::Sspi,
174 None,
175 gss_token_provider,
176 gss_token_provider_ex.as_ref(),
177 )
178 .map_err(|e| {
179 PgError::Auth(format!("SSPI initial token generation failed: {}", e))
180 })?;
181
182 self.send(FrontendMessage::GSSResponse(token)).await?;
183 }
184 BackendMessage::AuthenticationGSSContinue(server_token) => {
185 gss_roundtrips += 1;
186 if gss_roundtrips > MAX_GSS_ROUNDTRIPS {
187 return Err(PgError::Auth(format!(
188 "GSS handshake exceeded {} roundtrips — aborting",
189 MAX_GSS_ROUNDTRIPS
190 )));
191 }
192
193 let mechanism = match startup_auth_flow {
194 Some(StartupAuthFlow::EnterpriseGss { mechanism }) => mechanism,
195 Some(flow) => {
196 return Err(PgError::Protocol(format!(
197 "Received AuthenticationGSSContinue while {} authentication is in progress",
198 flow.label()
199 )));
200 }
201 None => {
202 return Err(PgError::Auth(
203 "Received GSSContinue without AuthenticationGSS/SSPI/KerberosV5 init"
204 .to_string(),
205 ));
206 }
207 };
208
209 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
210 return Err(PgError::Auth(
211 "Received GSSContinue but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
212 ));
213 }
214
215 let token = generate_gss_token(
216 gss_session_id,
217 mechanism,
218 Some(&server_token),
219 gss_token_provider,
220 gss_token_provider_ex.as_ref(),
221 )
222 .map_err(|e| {
223 PgError::Auth(format!("GSS continue token generation failed: {}", e))
224 })?;
225
226 if !token.is_empty() {
233 self.send(FrontendMessage::GSSResponse(token)).await?;
234 }
235 }
236 BackendMessage::AuthenticationCleartextPassword => {
237 if let Some(flow) = startup_auth_flow {
238 return Err(PgError::Protocol(format!(
239 "Received AuthenticationCleartextPassword while {} authentication is in progress",
240 flow.label()
241 )));
242 }
243 startup_auth_flow = Some(StartupAuthFlow::CleartextPassword);
244
245 if !auth_settings.allow_cleartext_password {
246 return Err(PgError::Auth(
247 "Server requested cleartext authentication, but cleartext is disabled by AuthSettings"
248 .to_string(),
249 ));
250 }
251 let password = password.ok_or_else(|| {
252 PgError::Auth("Password required for cleartext authentication".to_string())
253 })?;
254 self.send(FrontendMessage::PasswordMessage(password.to_string()))
255 .await?;
256 }
257 BackendMessage::AuthenticationMD5Password(salt) => {
258 if let Some(flow) = startup_auth_flow {
259 return Err(PgError::Protocol(format!(
260 "Received AuthenticationMD5Password while {} authentication is in progress",
261 flow.label()
262 )));
263 }
264 startup_auth_flow = Some(StartupAuthFlow::Md5Password);
265
266 if !auth_settings.allow_md5_password {
267 return Err(PgError::Auth(
268 "Server requested MD5 authentication, but MD5 is disabled by AuthSettings"
269 .to_string(),
270 ));
271 }
272 let password = password.ok_or_else(|| {
273 PgError::Auth("Password required for MD5 authentication".to_string())
274 })?;
275 let md5_password = md5_password_message(user, password, salt);
276 self.send(FrontendMessage::PasswordMessage(md5_password))
277 .await?;
278 }
279 BackendMessage::AuthenticationSASL(mechanisms) => {
280 if let Some(flow) = startup_auth_flow {
281 return Err(PgError::Protocol(format!(
282 "Received AuthenticationSASL while {} authentication is in progress",
283 flow.label()
284 )));
285 }
286 startup_auth_flow = Some(StartupAuthFlow::Scram {
287 server_final_seen: false,
288 });
289
290 if !auth_settings.allow_scram_sha_256 {
291 return Err(PgError::Auth(
292 "Server requested SCRAM authentication, but SCRAM is disabled by AuthSettings"
293 .to_string(),
294 ));
295 }
296 let password = password.ok_or_else(|| {
297 PgError::Auth("Password required for SCRAM authentication".to_string())
298 })?;
299
300 let tls_binding = self.tls_server_end_point_channel_binding();
301 let (mechanism, channel_binding_data) = select_scram_mechanism(
302 &mechanisms,
303 tls_binding,
304 auth_settings.channel_binding,
305 )
306 .map_err(PgError::Auth)?;
307
308 let client = if let Some(binding_data) = channel_binding_data {
309 ScramClient::new_with_tls_server_end_point(user, password, binding_data)
310 } else {
311 ScramClient::new(user, password)
312 };
313 let first_message = client.client_first_message();
314
315 self.send(FrontendMessage::SASLInitialResponse {
316 mechanism,
317 data: first_message,
318 })
319 .await?;
320
321 scram_client = Some(client);
322 }
323 BackendMessage::AuthenticationSASLContinue(server_data) => {
324 match startup_auth_flow {
325 Some(StartupAuthFlow::Scram {
326 server_final_seen: false,
327 }) => {}
328 Some(StartupAuthFlow::Scram {
329 server_final_seen: true,
330 }) => {
331 return Err(PgError::Protocol(
332 "Received AuthenticationSASLContinue after AuthenticationSASLFinal"
333 .to_string(),
334 ));
335 }
336 Some(flow) => {
337 return Err(PgError::Protocol(format!(
338 "Received AuthenticationSASLContinue while {} authentication is in progress",
339 flow.label()
340 )));
341 }
342 None => {
343 return Err(PgError::Auth(
344 "Received SASL Continue without SASL init".to_string(),
345 ));
346 }
347 }
348
349 let client = scram_client.as_mut().ok_or_else(|| {
350 PgError::Auth("Received SASL Continue without SASL init".to_string())
351 })?;
352
353 let final_message = client
354 .process_server_first(&server_data)
355 .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
356
357 self.send(FrontendMessage::SASLResponse(final_message))
358 .await?;
359 }
360 BackendMessage::AuthenticationSASLFinal(server_signature) => {
361 match startup_auth_flow {
362 Some(StartupAuthFlow::Scram {
363 server_final_seen: false,
364 }) => {
365 startup_auth_flow = Some(StartupAuthFlow::Scram {
366 server_final_seen: true,
367 });
368 }
369 Some(StartupAuthFlow::Scram {
370 server_final_seen: true,
371 }) => {
372 return Err(PgError::Protocol(
373 "Received duplicate AuthenticationSASLFinal".to_string(),
374 ));
375 }
376 Some(flow) => {
377 return Err(PgError::Protocol(format!(
378 "Received AuthenticationSASLFinal while {} authentication is in progress",
379 flow.label()
380 )));
381 }
382 None => {
383 return Err(PgError::Auth(
384 "Received SASL Final without SASL init".to_string(),
385 ));
386 }
387 }
388
389 let client = scram_client.as_ref().ok_or_else(|| {
390 PgError::Auth("Received SASL Final without SASL init".to_string())
391 })?;
392 client
393 .verify_server_final(&server_signature)
394 .map_err(|e| PgError::Auth(format!("Server verification failed: {}", e)))?;
395 }
396 BackendMessage::ParameterStatus { .. } => {
397 if !saw_auth_ok {
398 return Err(PgError::Protocol(
399 "Received ParameterStatus before AuthenticationOk".to_string(),
400 ));
401 }
402 }
403 BackendMessage::NegotiateProtocolVersion {
404 newest_minor_supported,
405 unrecognized_protocol_options,
406 } => {
407 if saw_auth_ok {
408 return Err(PgError::Protocol(
409 "Received NegotiateProtocolVersion after AuthenticationOk".to_string(),
410 ));
411 }
412 let negotiated = u16::try_from(newest_minor_supported).map_err(|_| {
413 PgError::Protocol(format!(
414 "Invalid NegotiateProtocolVersion newest_minor_supported: {}",
415 newest_minor_supported
416 ))
417 })?;
418 if negotiated > self.requested_protocol_minor {
419 return Err(PgError::Protocol(format!(
420 "Server negotiated protocol minor {} above requested {}",
421 negotiated, self.requested_protocol_minor
422 )));
423 }
424 self.negotiated_protocol_minor = negotiated;
425 if !unrecognized_protocol_options.is_empty() {
426 tracing::debug!(
427 negotiated_minor = negotiated,
428 unrecognized_count = unrecognized_protocol_options.len(),
429 "startup_negotiate_protocol_version"
430 );
431 }
432 }
433 BackendMessage::BackendKeyData {
434 process_id,
435 secret_key,
436 } => {
437 if !saw_auth_ok {
438 return Err(PgError::Protocol(
439 "Received BackendKeyData before AuthenticationOk".to_string(),
440 ));
441 }
442 self.process_id = process_id;
443 self.cancel_key_bytes = secret_key;
444 self.secret_key = if self.cancel_key_bytes.len() == 4 {
445 i32::from_be_bytes([
446 self.cancel_key_bytes[0],
447 self.cancel_key_bytes[1],
448 self.cancel_key_bytes[2],
449 self.cancel_key_bytes[3],
450 ])
451 } else {
452 0
453 };
454 }
455 BackendMessage::ReadyForQuery(TransactionStatus::Idle)
456 | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
457 | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
458 if !saw_auth_ok {
459 return Err(PgError::Protocol(
460 "Startup completed without AuthenticationOk".to_string(),
461 ));
462 }
463 return Ok(());
464 }
465 BackendMessage::ErrorResponse(err) => {
466 return Err(PgError::Connection(err.message));
467 }
468 BackendMessage::NoticeResponse(_) => {}
469 _ => {
470 return Err(PgError::Protocol(
471 "Unexpected backend message during startup".to_string(),
472 ));
473 }
474 }
475 }
476 }
477
478 fn tls_server_end_point_channel_binding(&self) -> Option<Vec<u8>> {
483 let PgStream::Tls(tls) = &self.stream else {
484 return None;
485 };
486
487 let (_, conn) = tls.get_ref();
488 let certs = conn.peer_certificates()?;
489 let leaf_cert = certs.first()?;
490
491 let mut hasher = Sha256::new();
492 hasher.update(leaf_cert.as_ref());
493 Some(hasher.finalize().to_vec())
494 }
495
496 pub async fn close(mut self) -> PgResult<()> {
499 use crate::protocol::PgEncoder;
500
501 let terminate = PgEncoder::encode_terminate();
503 self.write_all_with_timeout(&terminate, "stream write")
504 .await?;
505 self.flush_with_timeout("stream flush").await?;
506
507 Ok(())
508 }
509
510 pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
516
517 pub(crate) fn evict_prepared_if_full(&mut self) {
523 if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
524 if let Some((evicted_hash, evicted_name)) = self.stmt_cache.pop_lru() {
526 self.prepared_statements.remove(&evicted_name);
527 self.column_info_cache.remove(&evicted_hash);
528 self.pending_statement_closes.push(evicted_name);
529 } else {
530 if let Some(key) = self.prepared_statements.keys().next().cloned() {
534 self.prepared_statements.remove(&key);
535 self.pending_statement_closes.push(key);
536 }
537 }
538 }
539 }
540
541 pub(crate) fn clear_prepared_statement_state(&mut self) {
546 self.stmt_cache.clear();
547 self.prepared_statements.clear();
548 self.column_info_cache.clear();
549 self.pending_statement_closes.clear();
550 }
551}