qail_pg/driver/connection/
startup.rs1use super::helpers::{generate_gss_token, md5_password_message, select_scram_mechanism};
4use super::types::{GSS_SESSION_COUNTER, PgConnection, StartupAuthFlow};
5use crate::driver::stream::PgStream;
6use crate::driver::{
7 AuthSettings, EnterpriseAuthMechanism, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
8};
9use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
10use sha2::{Digest, Sha256};
11use std::sync::atomic::Ordering;
12
13impl PgConnection {
14 pub(super) async fn handle_startup(
16 &mut self,
17 user: &str,
18 password: Option<&str>,
19 auth_settings: AuthSettings,
20 gss_token_provider: Option<GssTokenProvider>,
21 gss_token_provider_ex: Option<GssTokenProviderEx>,
22 ) -> PgResult<()> {
23 let mut scram_client: Option<ScramClient> = None;
24 let mut startup_auth_flow: Option<StartupAuthFlow> = None;
25 let mut saw_auth_ok = false;
26 let gss_session_id = GSS_SESSION_COUNTER.fetch_add(1, Ordering::Relaxed);
27 let mut gss_roundtrips: u32 = 0;
28 const MAX_GSS_ROUNDTRIPS: u32 = 32;
29
30 loop {
31 let msg = self.recv().await?;
32 if saw_auth_ok
33 && matches!(
34 &msg,
35 BackendMessage::AuthenticationOk
36 | BackendMessage::AuthenticationKerberosV5
37 | BackendMessage::AuthenticationGSS
38 | BackendMessage::AuthenticationSCMCredential
39 | BackendMessage::AuthenticationGSSContinue(_)
40 | BackendMessage::AuthenticationSSPI
41 | BackendMessage::AuthenticationCleartextPassword
42 | BackendMessage::AuthenticationMD5Password(_)
43 | BackendMessage::AuthenticationSASL(_)
44 | BackendMessage::AuthenticationSASLContinue(_)
45 | BackendMessage::AuthenticationSASLFinal(_)
46 )
47 {
48 return Err(PgError::Protocol(
49 "Received authentication challenge after AuthenticationOk".to_string(),
50 ));
51 }
52 match msg {
53 BackendMessage::AuthenticationOk => {
54 if let Some(StartupAuthFlow::Scram {
55 server_final_seen: false,
56 }) = startup_auth_flow
57 {
58 return Err(PgError::Protocol(
59 "Received AuthenticationOk before AuthenticationSASLFinal".to_string(),
60 ));
61 }
62 saw_auth_ok = true;
63 }
64 BackendMessage::AuthenticationKerberosV5 => {
65 if let Some(flow) = startup_auth_flow {
66 return Err(PgError::Protocol(format!(
67 "Received AuthenticationKerberosV5 while {} authentication is in progress",
68 flow.label()
69 )));
70 }
71 startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
72 mechanism: EnterpriseAuthMechanism::KerberosV5,
73 });
74
75 if !auth_settings.allow_kerberos_v5 {
76 return Err(PgError::Auth(
77 "Server requested Kerberos V5 authentication, but Kerberos V5 is disabled by AuthSettings".to_string(),
78 ));
79 }
80
81 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
82 return Err(PgError::Auth(
83 "Kerberos V5 authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
84 ));
85 }
86
87 let token = generate_gss_token(
88 gss_session_id,
89 EnterpriseAuthMechanism::KerberosV5,
90 None,
91 gss_token_provider,
92 gss_token_provider_ex.as_ref(),
93 )
94 .map_err(|e| {
95 PgError::Auth(format!("Kerberos V5 token generation failed: {}", e))
96 })?;
97
98 self.send(FrontendMessage::GSSResponse(token)).await?;
99 }
100 BackendMessage::AuthenticationGSS => {
101 if let Some(flow) = startup_auth_flow {
102 return Err(PgError::Protocol(format!(
103 "Received AuthenticationGSS while {} authentication is in progress",
104 flow.label()
105 )));
106 }
107 startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
108 mechanism: EnterpriseAuthMechanism::GssApi,
109 });
110
111 if !auth_settings.allow_gssapi {
112 return Err(PgError::Auth(
113 "Server requested GSSAPI authentication, but GSSAPI is disabled by AuthSettings".to_string(),
114 ));
115 }
116
117 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
118 return Err(PgError::Auth(
119 "GSSAPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
120 ));
121 }
122
123 let token = generate_gss_token(
124 gss_session_id,
125 EnterpriseAuthMechanism::GssApi,
126 None,
127 gss_token_provider,
128 gss_token_provider_ex.as_ref(),
129 )
130 .map_err(|e| {
131 PgError::Auth(format!("GSSAPI initial token generation failed: {}", e))
132 })?;
133
134 self.send(FrontendMessage::GSSResponse(token)).await?;
135 }
136 BackendMessage::AuthenticationSCMCredential => {
137 if let Some(flow) = startup_auth_flow {
138 return Err(PgError::Protocol(format!(
139 "Received AuthenticationSCMCredential while {} authentication is in progress",
140 flow.label()
141 )));
142 }
143 return Err(PgError::Auth(
144 "Server requested SCM credential authentication (auth code 6). This driver currently does not support Unix-socket credential passing; use SCRAM, GSS/SSPI, or password auth for this connection."
145 .to_string(),
146 ));
147 }
148 BackendMessage::AuthenticationSSPI => {
149 if let Some(flow) = startup_auth_flow {
150 return Err(PgError::Protocol(format!(
151 "Received AuthenticationSSPI while {} authentication is in progress",
152 flow.label()
153 )));
154 }
155 startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
156 mechanism: EnterpriseAuthMechanism::Sspi,
157 });
158
159 if !auth_settings.allow_sspi {
160 return Err(PgError::Auth(
161 "Server requested SSPI authentication, but SSPI is disabled by AuthSettings".to_string(),
162 ));
163 }
164
165 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
166 return Err(PgError::Auth(
167 "SSPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
168 ));
169 }
170
171 let token = generate_gss_token(
172 gss_session_id,
173 EnterpriseAuthMechanism::Sspi,
174 None,
175 gss_token_provider,
176 gss_token_provider_ex.as_ref(),
177 )
178 .map_err(|e| {
179 PgError::Auth(format!("SSPI initial token generation failed: {}", e))
180 })?;
181
182 self.send(FrontendMessage::GSSResponse(token)).await?;
183 }
184 BackendMessage::AuthenticationGSSContinue(server_token) => {
185 gss_roundtrips += 1;
186 if gss_roundtrips > MAX_GSS_ROUNDTRIPS {
187 return Err(PgError::Auth(format!(
188 "GSS handshake exceeded {} roundtrips — aborting",
189 MAX_GSS_ROUNDTRIPS
190 )));
191 }
192
193 let mechanism = match startup_auth_flow {
194 Some(StartupAuthFlow::EnterpriseGss { mechanism }) => mechanism,
195 Some(flow) => {
196 return Err(PgError::Protocol(format!(
197 "Received AuthenticationGSSContinue while {} authentication is in progress",
198 flow.label()
199 )));
200 }
201 None => {
202 return Err(PgError::Auth(
203 "Received GSSContinue without AuthenticationGSS/SSPI/KerberosV5 init"
204 .to_string(),
205 ));
206 }
207 };
208
209 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
210 return Err(PgError::Auth(
211 "Received GSSContinue but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
212 ));
213 }
214
215 let token = generate_gss_token(
216 gss_session_id,
217 mechanism,
218 Some(&server_token),
219 gss_token_provider,
220 gss_token_provider_ex.as_ref(),
221 )
222 .map_err(|e| {
223 PgError::Auth(format!("GSS continue token generation failed: {}", e))
224 })?;
225
226 if !token.is_empty() {
233 self.send(FrontendMessage::GSSResponse(token)).await?;
234 }
235 }
236 BackendMessage::AuthenticationCleartextPassword => {
237 if let Some(flow) = startup_auth_flow {
238 return Err(PgError::Protocol(format!(
239 "Received AuthenticationCleartextPassword while {} authentication is in progress",
240 flow.label()
241 )));
242 }
243 startup_auth_flow = Some(StartupAuthFlow::CleartextPassword);
244
245 if !auth_settings.allow_cleartext_password {
246 return Err(PgError::Auth(
247 "Server requested cleartext authentication, but cleartext is disabled by AuthSettings"
248 .to_string(),
249 ));
250 }
251 let password = password.ok_or_else(|| {
252 PgError::Auth("Password required for cleartext authentication".to_string())
253 })?;
254 self.send(FrontendMessage::PasswordMessage(password.to_string()))
255 .await?;
256 }
257 BackendMessage::AuthenticationMD5Password(salt) => {
258 if let Some(flow) = startup_auth_flow {
259 return Err(PgError::Protocol(format!(
260 "Received AuthenticationMD5Password while {} authentication is in progress",
261 flow.label()
262 )));
263 }
264 startup_auth_flow = Some(StartupAuthFlow::Md5Password);
265
266 if !auth_settings.allow_md5_password {
267 return Err(PgError::Auth(
268 "Server requested MD5 authentication, but MD5 is disabled by AuthSettings"
269 .to_string(),
270 ));
271 }
272 let password = password.ok_or_else(|| {
273 PgError::Auth("Password required for MD5 authentication".to_string())
274 })?;
275 let md5_password = md5_password_message(user, password, salt);
276 self.send(FrontendMessage::PasswordMessage(md5_password))
277 .await?;
278 }
279 BackendMessage::AuthenticationSASL(mechanisms) => {
280 if let Some(flow) = startup_auth_flow {
281 return Err(PgError::Protocol(format!(
282 "Received AuthenticationSASL while {} authentication is in progress",
283 flow.label()
284 )));
285 }
286 startup_auth_flow = Some(StartupAuthFlow::Scram {
287 server_final_seen: false,
288 });
289
290 if !auth_settings.allow_scram_sha_256 {
291 return Err(PgError::Auth(
292 "Server requested SCRAM authentication, but SCRAM is disabled by AuthSettings"
293 .to_string(),
294 ));
295 }
296 let password = password.ok_or_else(|| {
297 PgError::Auth("Password required for SCRAM authentication".to_string())
298 })?;
299
300 let tls_binding = self.tls_server_end_point_channel_binding();
301 let (mechanism, channel_binding_data) = select_scram_mechanism(
302 &mechanisms,
303 tls_binding,
304 auth_settings.channel_binding,
305 )
306 .map_err(PgError::Auth)?;
307
308 let client = if let Some(binding_data) = channel_binding_data {
309 ScramClient::new_with_tls_server_end_point(user, password, binding_data)
310 } else {
311 ScramClient::new(user, password)
312 };
313 let first_message = client.client_first_message();
314
315 self.send(FrontendMessage::SASLInitialResponse {
316 mechanism,
317 data: first_message,
318 })
319 .await?;
320
321 scram_client = Some(client);
322 }
323 BackendMessage::AuthenticationSASLContinue(server_data) => {
324 match startup_auth_flow {
325 Some(StartupAuthFlow::Scram {
326 server_final_seen: false,
327 }) => {}
328 Some(StartupAuthFlow::Scram {
329 server_final_seen: true,
330 }) => {
331 return Err(PgError::Protocol(
332 "Received AuthenticationSASLContinue after AuthenticationSASLFinal"
333 .to_string(),
334 ));
335 }
336 Some(flow) => {
337 return Err(PgError::Protocol(format!(
338 "Received AuthenticationSASLContinue while {} authentication is in progress",
339 flow.label()
340 )));
341 }
342 None => {
343 return Err(PgError::Auth(
344 "Received SASL Continue without SASL init".to_string(),
345 ));
346 }
347 }
348
349 let client = scram_client.as_mut().ok_or_else(|| {
350 PgError::Auth("Received SASL Continue without SASL init".to_string())
351 })?;
352
353 let final_message = client
354 .process_server_first(&server_data)
355 .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
356
357 self.send(FrontendMessage::SASLResponse(final_message))
358 .await?;
359 }
360 BackendMessage::AuthenticationSASLFinal(server_signature) => {
361 match startup_auth_flow {
362 Some(StartupAuthFlow::Scram {
363 server_final_seen: false,
364 }) => {
365 startup_auth_flow = Some(StartupAuthFlow::Scram {
366 server_final_seen: true,
367 });
368 }
369 Some(StartupAuthFlow::Scram {
370 server_final_seen: true,
371 }) => {
372 return Err(PgError::Protocol(
373 "Received duplicate AuthenticationSASLFinal".to_string(),
374 ));
375 }
376 Some(flow) => {
377 return Err(PgError::Protocol(format!(
378 "Received AuthenticationSASLFinal while {} authentication is in progress",
379 flow.label()
380 )));
381 }
382 None => {
383 return Err(PgError::Auth(
384 "Received SASL Final without SASL init".to_string(),
385 ));
386 }
387 }
388
389 let client = scram_client.as_ref().ok_or_else(|| {
390 PgError::Auth("Received SASL Final without SASL init".to_string())
391 })?;
392 client
393 .verify_server_final(&server_signature)
394 .map_err(|e| PgError::Auth(format!("Server verification failed: {}", e)))?;
395 }
396 BackendMessage::ParameterStatus { .. } => {
397 if !saw_auth_ok {
398 return Err(PgError::Protocol(
399 "Received ParameterStatus before AuthenticationOk".to_string(),
400 ));
401 }
402 }
403 BackendMessage::BackendKeyData {
404 process_id,
405 secret_key,
406 } => {
407 if !saw_auth_ok {
408 return Err(PgError::Protocol(
409 "Received BackendKeyData before AuthenticationOk".to_string(),
410 ));
411 }
412 self.process_id = process_id;
413 self.secret_key = secret_key;
414 }
415 BackendMessage::ReadyForQuery(TransactionStatus::Idle)
416 | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
417 | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
418 if !saw_auth_ok {
419 return Err(PgError::Protocol(
420 "Startup completed without AuthenticationOk".to_string(),
421 ));
422 }
423 return Ok(());
424 }
425 BackendMessage::ErrorResponse(err) => {
426 return Err(PgError::Connection(err.message));
427 }
428 BackendMessage::NoticeResponse(_) => {}
429 _ => {
430 return Err(PgError::Protocol(
431 "Unexpected backend message during startup".to_string(),
432 ));
433 }
434 }
435 }
436 }
437
438 fn tls_server_end_point_channel_binding(&self) -> Option<Vec<u8>> {
443 let PgStream::Tls(tls) = &self.stream else {
444 return None;
445 };
446
447 let (_, conn) = tls.get_ref();
448 let certs = conn.peer_certificates()?;
449 let leaf_cert = certs.first()?;
450
451 let mut hasher = Sha256::new();
452 hasher.update(leaf_cert.as_ref());
453 Some(hasher.finalize().to_vec())
454 }
455
456 pub async fn close(mut self) -> PgResult<()> {
459 use crate::protocol::PgEncoder;
460
461 let terminate = PgEncoder::encode_terminate();
463 self.write_all_with_timeout(&terminate, "stream write")
464 .await?;
465 self.flush_with_timeout("stream flush").await?;
466
467 Ok(())
468 }
469
470 pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
476
477 pub(crate) fn evict_prepared_if_full(&mut self) {
483 if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
484 if let Some((evicted_hash, evicted_name)) = self.stmt_cache.pop_lru() {
486 self.prepared_statements.remove(&evicted_name);
487 self.column_info_cache.remove(&evicted_hash);
488 self.pending_statement_closes.push(evicted_name);
489 } else {
490 if let Some(key) = self.prepared_statements.keys().next().cloned() {
494 self.prepared_statements.remove(&key);
495 self.pending_statement_closes.push(key);
496 }
497 }
498 }
499 }
500
501 pub(crate) fn clear_prepared_statement_state(&mut self) {
506 self.stmt_cache.clear();
507 self.prepared_statements.clear();
508 self.column_info_cache.clear();
509 self.pending_statement_closes.clear();
510 }
511}