qail_pg/driver/connection/
startup.rs1use super::helpers::{generate_gss_token, md5_password_message, select_scram_mechanism};
4use super::types::{GSS_SESSION_COUNTER, PgConnection, StartupAuthFlow};
5use crate::driver::stream::PgStream;
6use crate::driver::{
7 AuthSettings, EnterpriseAuthMechanism, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
8};
9use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
use sha2::{Digest, Sha224, Sha256, Sha384, Sha512};
11use std::sync::atomic::Ordering;
12
13impl PgConnection {
14 pub(super) async fn handle_startup(
16 &mut self,
17 user: &str,
18 password: Option<&str>,
19 auth_settings: AuthSettings,
20 gss_token_provider: Option<GssTokenProvider>,
21 gss_token_provider_ex: Option<GssTokenProviderEx>,
22 ) -> PgResult<()> {
23 let mut scram_client: Option<ScramClient> = None;
24 let mut startup_auth_flow: Option<StartupAuthFlow> = None;
25 let mut saw_auth_ok = false;
26 let gss_session_id = GSS_SESSION_COUNTER.fetch_add(1, Ordering::Relaxed);
27 let mut gss_roundtrips: u32 = 0;
28 const MAX_GSS_ROUNDTRIPS: u32 = 32;
29
30 loop {
31 let msg = self.recv().await?;
32 if saw_auth_ok
33 && matches!(
34 &msg,
35 BackendMessage::AuthenticationOk
36 | BackendMessage::AuthenticationKerberosV5
37 | BackendMessage::AuthenticationGSS
38 | BackendMessage::AuthenticationSCMCredential
39 | BackendMessage::AuthenticationGSSContinue(_)
40 | BackendMessage::AuthenticationSSPI
41 | BackendMessage::AuthenticationCleartextPassword
42 | BackendMessage::AuthenticationMD5Password(_)
43 | BackendMessage::AuthenticationSASL(_)
44 | BackendMessage::AuthenticationSASLContinue(_)
45 | BackendMessage::AuthenticationSASLFinal(_)
46 )
47 {
48 return Err(PgError::Protocol(
49 "Received authentication challenge after AuthenticationOk".to_string(),
50 ));
51 }
52 match msg {
53 BackendMessage::AuthenticationOk => {
54 if let Some(StartupAuthFlow::Scram {
55 server_final_seen: false,
56 }) = startup_auth_flow
57 {
58 return Err(PgError::Protocol(
59 "Received AuthenticationOk before AuthenticationSASLFinal".to_string(),
60 ));
61 }
62 saw_auth_ok = true;
63 }
64 BackendMessage::AuthenticationKerberosV5 => {
65 if let Some(flow) = startup_auth_flow {
66 return Err(PgError::Protocol(format!(
67 "Received AuthenticationKerberosV5 while {} authentication is in progress",
68 flow.label()
69 )));
70 }
71 startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
72 mechanism: EnterpriseAuthMechanism::KerberosV5,
73 });
74
75 if !auth_settings.allow_kerberos_v5 {
76 return Err(PgError::Auth(
77 "Server requested Kerberos V5 authentication, but Kerberos V5 is disabled by AuthSettings".to_string(),
78 ));
79 }
80
81 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
82 return Err(PgError::Auth(
83 "Kerberos V5 authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
84 ));
85 }
86
87 let token = generate_gss_token(
88 gss_session_id,
89 EnterpriseAuthMechanism::KerberosV5,
90 None,
91 gss_token_provider,
92 gss_token_provider_ex.as_ref(),
93 )
94 .map_err(|e| {
95 PgError::Auth(format!("Kerberos V5 token generation failed: {}", e))
96 })?;
97
98 self.send(FrontendMessage::GSSResponse(token)).await?;
99 }
100 BackendMessage::AuthenticationGSS => {
101 if let Some(flow) = startup_auth_flow {
102 return Err(PgError::Protocol(format!(
103 "Received AuthenticationGSS while {} authentication is in progress",
104 flow.label()
105 )));
106 }
107 startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
108 mechanism: EnterpriseAuthMechanism::GssApi,
109 });
110
111 if !auth_settings.allow_gssapi {
112 return Err(PgError::Auth(
113 "Server requested GSSAPI authentication, but GSSAPI is disabled by AuthSettings".to_string(),
114 ));
115 }
116
117 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
118 return Err(PgError::Auth(
119 "GSSAPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
120 ));
121 }
122
123 let token = generate_gss_token(
124 gss_session_id,
125 EnterpriseAuthMechanism::GssApi,
126 None,
127 gss_token_provider,
128 gss_token_provider_ex.as_ref(),
129 )
130 .map_err(|e| {
131 PgError::Auth(format!("GSSAPI initial token generation failed: {}", e))
132 })?;
133
134 self.send(FrontendMessage::GSSResponse(token)).await?;
135 }
136 BackendMessage::AuthenticationSCMCredential => {
137 if let Some(flow) = startup_auth_flow {
138 return Err(PgError::Protocol(format!(
139 "Received AuthenticationSCMCredential while {} authentication is in progress",
140 flow.label()
141 )));
142 }
143 return Err(PgError::Auth(
144 "Server requested SCM credential authentication (auth code 6). This driver currently does not support Unix-socket credential passing; use SCRAM, GSS/SSPI, or password auth for this connection."
145 .to_string(),
146 ));
147 }
148 BackendMessage::AuthenticationSSPI => {
149 if let Some(flow) = startup_auth_flow {
150 return Err(PgError::Protocol(format!(
151 "Received AuthenticationSSPI while {} authentication is in progress",
152 flow.label()
153 )));
154 }
155 startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
156 mechanism: EnterpriseAuthMechanism::Sspi,
157 });
158
159 if !auth_settings.allow_sspi {
160 return Err(PgError::Auth(
161 "Server requested SSPI authentication, but SSPI is disabled by AuthSettings".to_string(),
162 ));
163 }
164
165 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
166 return Err(PgError::Auth(
167 "SSPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
168 ));
169 }
170
171 let token = generate_gss_token(
172 gss_session_id,
173 EnterpriseAuthMechanism::Sspi,
174 None,
175 gss_token_provider,
176 gss_token_provider_ex.as_ref(),
177 )
178 .map_err(|e| {
179 PgError::Auth(format!("SSPI initial token generation failed: {}", e))
180 })?;
181
182 self.send(FrontendMessage::GSSResponse(token)).await?;
183 }
184 BackendMessage::AuthenticationGSSContinue(server_token) => {
185 gss_roundtrips += 1;
186 if gss_roundtrips > MAX_GSS_ROUNDTRIPS {
187 return Err(PgError::Auth(format!(
188 "GSS handshake exceeded {} roundtrips — aborting",
189 MAX_GSS_ROUNDTRIPS
190 )));
191 }
192
193 let mechanism = match startup_auth_flow {
194 Some(StartupAuthFlow::EnterpriseGss { mechanism }) => mechanism,
195 Some(flow) => {
196 return Err(PgError::Protocol(format!(
197 "Received AuthenticationGSSContinue while {} authentication is in progress",
198 flow.label()
199 )));
200 }
201 None => {
202 return Err(PgError::Auth(
203 "Received GSSContinue without AuthenticationGSS/SSPI/KerberosV5 init"
204 .to_string(),
205 ));
206 }
207 };
208
209 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
210 return Err(PgError::Auth(
211 "Received GSSContinue but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
212 ));
213 }
214
215 let token = generate_gss_token(
216 gss_session_id,
217 mechanism,
218 Some(&server_token),
219 gss_token_provider,
220 gss_token_provider_ex.as_ref(),
221 )
222 .map_err(|e| {
223 PgError::Auth(format!("GSS continue token generation failed: {}", e))
224 })?;
225
226 if !token.is_empty() {
233 self.send(FrontendMessage::GSSResponse(token)).await?;
234 }
235 }
236 BackendMessage::AuthenticationCleartextPassword => {
237 if let Some(flow) = startup_auth_flow {
238 return Err(PgError::Protocol(format!(
239 "Received AuthenticationCleartextPassword while {} authentication is in progress",
240 flow.label()
241 )));
242 }
243 startup_auth_flow = Some(StartupAuthFlow::CleartextPassword);
244
245 if !auth_settings.allow_cleartext_password {
246 return Err(PgError::Auth(
247 "Server requested cleartext authentication, but cleartext is disabled by AuthSettings"
248 .to_string(),
249 ));
250 }
251 let password = password.ok_or_else(|| {
252 PgError::Auth("Password required for cleartext authentication".to_string())
253 })?;
254 self.send(FrontendMessage::PasswordMessage(password.to_string()))
255 .await?;
256 }
257 BackendMessage::AuthenticationMD5Password(salt) => {
258 if let Some(flow) = startup_auth_flow {
259 return Err(PgError::Protocol(format!(
260 "Received AuthenticationMD5Password while {} authentication is in progress",
261 flow.label()
262 )));
263 }
264 startup_auth_flow = Some(StartupAuthFlow::Md5Password);
265
266 if !auth_settings.allow_md5_password {
267 return Err(PgError::Auth(
268 "Server requested MD5 authentication, but MD5 is disabled by AuthSettings"
269 .to_string(),
270 ));
271 }
272 let password = password.ok_or_else(|| {
273 PgError::Auth("Password required for MD5 authentication".to_string())
274 })?;
275 let md5_password = md5_password_message(user, password, salt);
276 self.send(FrontendMessage::PasswordMessage(md5_password))
277 .await?;
278 }
279 BackendMessage::AuthenticationSASL(mechanisms) => {
280 if let Some(flow) = startup_auth_flow {
281 return Err(PgError::Protocol(format!(
282 "Received AuthenticationSASL while {} authentication is in progress",
283 flow.label()
284 )));
285 }
286 startup_auth_flow = Some(StartupAuthFlow::Scram {
287 server_final_seen: false,
288 });
289
290 if !auth_settings.allow_scram_sha_256 {
291 return Err(PgError::Auth(
292 "Server requested SCRAM authentication, but SCRAM is disabled by AuthSettings"
293 .to_string(),
294 ));
295 }
296 let password = password.ok_or_else(|| {
297 PgError::Auth("Password required for SCRAM authentication".to_string())
298 })?;
299
300 let tls_binding = self.tls_server_end_point_channel_binding();
301 let (mechanism, channel_binding_data) = select_scram_mechanism(
302 &mechanisms,
303 tls_binding,
304 auth_settings.channel_binding,
305 )
306 .map_err(PgError::Auth)?;
307
308 let client = if let Some(binding_data) = channel_binding_data {
309 ScramClient::new_with_tls_server_end_point(user, password, binding_data)
310 } else {
311 ScramClient::new(user, password)
312 };
313 let first_message = client.client_first_message();
314
315 self.send(FrontendMessage::SASLInitialResponse {
316 mechanism,
317 data: first_message,
318 })
319 .await?;
320
321 scram_client = Some(client);
322 }
323 BackendMessage::AuthenticationSASLContinue(server_data) => {
324 match startup_auth_flow {
325 Some(StartupAuthFlow::Scram {
326 server_final_seen: false,
327 }) => {}
328 Some(StartupAuthFlow::Scram {
329 server_final_seen: true,
330 }) => {
331 return Err(PgError::Protocol(
332 "Received AuthenticationSASLContinue after AuthenticationSASLFinal"
333 .to_string(),
334 ));
335 }
336 Some(flow) => {
337 return Err(PgError::Protocol(format!(
338 "Received AuthenticationSASLContinue while {} authentication is in progress",
339 flow.label()
340 )));
341 }
342 None => {
343 return Err(PgError::Auth(
344 "Received SASL Continue without SASL init".to_string(),
345 ));
346 }
347 }
348
349 let client = scram_client.as_mut().ok_or_else(|| {
350 PgError::Auth("Received SASL Continue without SASL init".to_string())
351 })?;
352
353 let final_message = client
354 .process_server_first(&server_data)
355 .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
356
357 self.send(FrontendMessage::SASLResponse(final_message))
358 .await?;
359 }
360 BackendMessage::AuthenticationSASLFinal(server_signature) => {
361 match startup_auth_flow {
362 Some(StartupAuthFlow::Scram {
363 server_final_seen: false,
364 }) => {
365 startup_auth_flow = Some(StartupAuthFlow::Scram {
366 server_final_seen: true,
367 });
368 }
369 Some(StartupAuthFlow::Scram {
370 server_final_seen: true,
371 }) => {
372 return Err(PgError::Protocol(
373 "Received duplicate AuthenticationSASLFinal".to_string(),
374 ));
375 }
376 Some(flow) => {
377 return Err(PgError::Protocol(format!(
378 "Received AuthenticationSASLFinal while {} authentication is in progress",
379 flow.label()
380 )));
381 }
382 None => {
383 return Err(PgError::Auth(
384 "Received SASL Final without SASL init".to_string(),
385 ));
386 }
387 }
388
389 let client = scram_client.as_ref().ok_or_else(|| {
390 PgError::Auth("Received SASL Final without SASL init".to_string())
391 })?;
392 client
393 .verify_server_final(&server_signature)
394 .map_err(|e| PgError::Auth(format!("Server verification failed: {}", e)))?;
395 }
396 BackendMessage::ParameterStatus { .. } => {
397 if !saw_auth_ok {
398 return Err(PgError::Protocol(
399 "Received ParameterStatus before AuthenticationOk".to_string(),
400 ));
401 }
402 }
403 BackendMessage::NegotiateProtocolVersion {
404 newest_minor_supported,
405 unrecognized_protocol_options,
406 } => {
407 if saw_auth_ok {
408 return Err(PgError::Protocol(
409 "Received NegotiateProtocolVersion after AuthenticationOk".to_string(),
410 ));
411 }
412 let negotiated = if let Ok(minor) = u16::try_from(newest_minor_supported) {
413 minor
414 } else {
415 let packed = u32::try_from(newest_minor_supported).map_err(|_| {
416 PgError::Protocol(format!(
417 "Invalid NegotiateProtocolVersion newest_minor_supported: {}",
418 newest_minor_supported
419 ))
420 })?;
421 let major = (packed >> 16) as u16;
422 let minor = (packed & 0xFFFF) as u16;
423 if major != 3 {
424 return Err(PgError::Protocol(format!(
425 "Invalid NegotiateProtocolVersion newest_minor_supported: {}",
426 newest_minor_supported
427 )));
428 }
429 minor
430 };
431 if negotiated > self.requested_protocol_minor {
432 return Err(PgError::Protocol(format!(
433 "Server negotiated protocol minor {} above requested {}",
434 negotiated, self.requested_protocol_minor
435 )));
436 }
437 self.negotiated_protocol_minor = negotiated;
438 if !unrecognized_protocol_options.is_empty() {
439 tracing::debug!(
440 negotiated_minor = negotiated,
441 unrecognized_count = unrecognized_protocol_options.len(),
442 "startup_negotiate_protocol_version"
443 );
444 }
445 }
446 BackendMessage::BackendKeyData {
447 process_id,
448 secret_key,
449 } => {
450 if !saw_auth_ok {
451 return Err(PgError::Protocol(
452 "Received BackendKeyData before AuthenticationOk".to_string(),
453 ));
454 }
455 self.process_id = process_id;
456 self.cancel_key_bytes = secret_key;
457 }
458 BackendMessage::ReadyForQuery(TransactionStatus::Idle)
459 | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
460 | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
461 if !saw_auth_ok {
462 return Err(PgError::Protocol(
463 "Startup completed without AuthenticationOk".to_string(),
464 ));
465 }
466 return Ok(());
467 }
468 BackendMessage::ErrorResponse(err) => {
469 return Err(PgError::Connection(err.message));
470 }
471 BackendMessage::NoticeResponse(_) => {}
472 _ => {
473 return Err(PgError::Protocol(
474 "Unexpected backend message during startup".to_string(),
475 ));
476 }
477 }
478 }
479 }
480
481 fn tls_server_end_point_channel_binding(&self) -> Option<Vec<u8>> {
486 let PgStream::Tls(tls) = &self.stream else {
487 return None;
488 };
489
490 let (_, conn) = tls.get_ref();
491 let certs = conn.peer_certificates()?;
492 let leaf_cert = certs.first()?;
493
494 let mut hasher = Sha256::new();
495 hasher.update(leaf_cert.as_ref());
496 Some(hasher.finalize().to_vec())
497 }
498
499 pub async fn close(mut self) -> PgResult<()> {
502 use crate::protocol::PgEncoder;
503
504 let terminate = PgEncoder::encode_terminate();
506 self.write_all_with_timeout(&terminate, "stream write")
507 .await?;
508 self.flush_with_timeout("stream flush").await?;
509
510 Ok(())
511 }
512
513 pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
519
520 pub(crate) fn evict_prepared_if_full(&mut self) {
526 if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
527 if let Some((evicted_hash, evicted_name)) = self.stmt_cache.pop_lru() {
529 self.prepared_statements.remove(&evicted_name);
530 self.column_info_cache.remove(&evicted_hash);
531 self.pending_statement_closes.push(evicted_name);
532 } else {
533 if let Some(key) = self.prepared_statements.keys().next().cloned() {
537 self.prepared_statements.remove(&key);
538 self.pending_statement_closes.push(key);
539 }
540 }
541 }
542 }
543
544 pub(crate) fn clear_prepared_statement_state(&mut self) {
549 self.stmt_cache.clear();
550 self.prepared_statements.clear();
551 self.column_info_cache.clear();
552 self.pending_statement_closes.clear();
553 }
554}