qail_pg/driver/connection/
startup.rs1use super::helpers::{generate_gss_token, md5_password_message, select_scram_mechanism};
4use super::types::{GSS_SESSION_COUNTER, PgConnection, StartupAuthFlow};
5use crate::driver::stream::PgStream;
6use crate::driver::{
7 AuthSettings, EnterpriseAuthMechanism, GssTokenProvider, GssTokenProviderEx, PgError, PgResult,
8};
9use crate::protocol::{BackendMessage, FrontendMessage, ScramClient, TransactionStatus};
10use sha2::{Digest, Sha256};
11use std::sync::atomic::Ordering;
12
13impl PgConnection {
14 pub(super) async fn handle_startup(
16 &mut self,
17 user: &str,
18 password: Option<&str>,
19 auth_settings: AuthSettings,
20 gss_token_provider: Option<GssTokenProvider>,
21 gss_token_provider_ex: Option<GssTokenProviderEx>,
22 ) -> PgResult<()> {
23 let mut scram_client: Option<ScramClient> = None;
24 let mut startup_auth_flow: Option<StartupAuthFlow> = None;
25 let mut saw_auth_ok = false;
26 let gss_session_id = GSS_SESSION_COUNTER.fetch_add(1, Ordering::Relaxed);
27 let mut gss_roundtrips: u32 = 0;
28 const MAX_GSS_ROUNDTRIPS: u32 = 32;
29
30 loop {
31 let msg = self.recv().await?;
32 if saw_auth_ok
33 && matches!(
34 &msg,
35 BackendMessage::AuthenticationOk
36 | BackendMessage::AuthenticationKerberosV5
37 | BackendMessage::AuthenticationGSS
38 | BackendMessage::AuthenticationGSSContinue(_)
39 | BackendMessage::AuthenticationSSPI
40 | BackendMessage::AuthenticationCleartextPassword
41 | BackendMessage::AuthenticationMD5Password(_)
42 | BackendMessage::AuthenticationSASL(_)
43 | BackendMessage::AuthenticationSASLContinue(_)
44 | BackendMessage::AuthenticationSASLFinal(_)
45 )
46 {
47 return Err(PgError::Protocol(
48 "Received authentication challenge after AuthenticationOk".to_string(),
49 ));
50 }
51 match msg {
52 BackendMessage::AuthenticationOk => {
53 if let Some(StartupAuthFlow::Scram {
54 server_final_seen: false,
55 }) = startup_auth_flow
56 {
57 return Err(PgError::Protocol(
58 "Received AuthenticationOk before AuthenticationSASLFinal".to_string(),
59 ));
60 }
61 saw_auth_ok = true;
62 }
63 BackendMessage::AuthenticationKerberosV5 => {
64 if let Some(flow) = startup_auth_flow {
65 return Err(PgError::Protocol(format!(
66 "Received AuthenticationKerberosV5 while {} authentication is in progress",
67 flow.label()
68 )));
69 }
70 startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
71 mechanism: EnterpriseAuthMechanism::KerberosV5,
72 });
73
74 if !auth_settings.allow_kerberos_v5 {
75 return Err(PgError::Auth(
76 "Server requested Kerberos V5 authentication, but Kerberos V5 is disabled by AuthSettings".to_string(),
77 ));
78 }
79
80 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
81 return Err(PgError::Auth(
82 "Kerberos V5 authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
83 ));
84 }
85
86 let token = generate_gss_token(
87 gss_session_id,
88 EnterpriseAuthMechanism::KerberosV5,
89 None,
90 gss_token_provider,
91 gss_token_provider_ex.as_ref(),
92 )
93 .map_err(|e| {
94 PgError::Auth(format!("Kerberos V5 token generation failed: {}", e))
95 })?;
96
97 self.send(FrontendMessage::GSSResponse(token)).await?;
98 }
99 BackendMessage::AuthenticationGSS => {
100 if let Some(flow) = startup_auth_flow {
101 return Err(PgError::Protocol(format!(
102 "Received AuthenticationGSS while {} authentication is in progress",
103 flow.label()
104 )));
105 }
106 startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
107 mechanism: EnterpriseAuthMechanism::GssApi,
108 });
109
110 if !auth_settings.allow_gssapi {
111 return Err(PgError::Auth(
112 "Server requested GSSAPI authentication, but GSSAPI is disabled by AuthSettings".to_string(),
113 ));
114 }
115
116 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
117 return Err(PgError::Auth(
118 "GSSAPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
119 ));
120 }
121
122 let token = generate_gss_token(
123 gss_session_id,
124 EnterpriseAuthMechanism::GssApi,
125 None,
126 gss_token_provider,
127 gss_token_provider_ex.as_ref(),
128 )
129 .map_err(|e| {
130 PgError::Auth(format!("GSSAPI initial token generation failed: {}", e))
131 })?;
132
133 self.send(FrontendMessage::GSSResponse(token)).await?;
134 }
135 BackendMessage::AuthenticationSSPI => {
136 if let Some(flow) = startup_auth_flow {
137 return Err(PgError::Protocol(format!(
138 "Received AuthenticationSSPI while {} authentication is in progress",
139 flow.label()
140 )));
141 }
142 startup_auth_flow = Some(StartupAuthFlow::EnterpriseGss {
143 mechanism: EnterpriseAuthMechanism::Sspi,
144 });
145
146 if !auth_settings.allow_sspi {
147 return Err(PgError::Auth(
148 "Server requested SSPI authentication, but SSPI is disabled by AuthSettings".to_string(),
149 ));
150 }
151
152 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
153 return Err(PgError::Auth(
154 "SSPI authentication requested but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
155 ));
156 }
157
158 let token = generate_gss_token(
159 gss_session_id,
160 EnterpriseAuthMechanism::Sspi,
161 None,
162 gss_token_provider,
163 gss_token_provider_ex.as_ref(),
164 )
165 .map_err(|e| {
166 PgError::Auth(format!("SSPI initial token generation failed: {}", e))
167 })?;
168
169 self.send(FrontendMessage::GSSResponse(token)).await?;
170 }
171 BackendMessage::AuthenticationGSSContinue(server_token) => {
172 gss_roundtrips += 1;
173 if gss_roundtrips > MAX_GSS_ROUNDTRIPS {
174 return Err(PgError::Auth(format!(
175 "GSS handshake exceeded {} roundtrips — aborting",
176 MAX_GSS_ROUNDTRIPS
177 )));
178 }
179
180 let mechanism = match startup_auth_flow {
181 Some(StartupAuthFlow::EnterpriseGss { mechanism }) => mechanism,
182 Some(flow) => {
183 return Err(PgError::Protocol(format!(
184 "Received AuthenticationGSSContinue while {} authentication is in progress",
185 flow.label()
186 )));
187 }
188 None => {
189 return Err(PgError::Auth(
190 "Received GSSContinue without AuthenticationGSS/SSPI/KerberosV5 init"
191 .to_string(),
192 ));
193 }
194 };
195
196 if gss_token_provider.is_none() && gss_token_provider_ex.is_none() {
197 return Err(PgError::Auth(
198 "Received GSSContinue but no GSS token provider is configured. Set ConnectOptions.gss_token_provider or ConnectOptions.gss_token_provider_ex.".to_string(),
199 ));
200 }
201
202 let token = generate_gss_token(
203 gss_session_id,
204 mechanism,
205 Some(&server_token),
206 gss_token_provider,
207 gss_token_provider_ex.as_ref(),
208 )
209 .map_err(|e| {
210 PgError::Auth(format!("GSS continue token generation failed: {}", e))
211 })?;
212
213 if !token.is_empty() {
220 self.send(FrontendMessage::GSSResponse(token)).await?;
221 }
222 }
223 BackendMessage::AuthenticationCleartextPassword => {
224 if let Some(flow) = startup_auth_flow {
225 return Err(PgError::Protocol(format!(
226 "Received AuthenticationCleartextPassword while {} authentication is in progress",
227 flow.label()
228 )));
229 }
230 startup_auth_flow = Some(StartupAuthFlow::CleartextPassword);
231
232 if !auth_settings.allow_cleartext_password {
233 return Err(PgError::Auth(
234 "Server requested cleartext authentication, but cleartext is disabled by AuthSettings"
235 .to_string(),
236 ));
237 }
238 let password = password.ok_or_else(|| {
239 PgError::Auth("Password required for cleartext authentication".to_string())
240 })?;
241 self.send(FrontendMessage::PasswordMessage(password.to_string()))
242 .await?;
243 }
244 BackendMessage::AuthenticationMD5Password(salt) => {
245 if let Some(flow) = startup_auth_flow {
246 return Err(PgError::Protocol(format!(
247 "Received AuthenticationMD5Password while {} authentication is in progress",
248 flow.label()
249 )));
250 }
251 startup_auth_flow = Some(StartupAuthFlow::Md5Password);
252
253 if !auth_settings.allow_md5_password {
254 return Err(PgError::Auth(
255 "Server requested MD5 authentication, but MD5 is disabled by AuthSettings"
256 .to_string(),
257 ));
258 }
259 let password = password.ok_or_else(|| {
260 PgError::Auth("Password required for MD5 authentication".to_string())
261 })?;
262 let md5_password = md5_password_message(user, password, salt);
263 self.send(FrontendMessage::PasswordMessage(md5_password))
264 .await?;
265 }
266 BackendMessage::AuthenticationSASL(mechanisms) => {
267 if let Some(flow) = startup_auth_flow {
268 return Err(PgError::Protocol(format!(
269 "Received AuthenticationSASL while {} authentication is in progress",
270 flow.label()
271 )));
272 }
273 startup_auth_flow = Some(StartupAuthFlow::Scram {
274 server_final_seen: false,
275 });
276
277 if !auth_settings.allow_scram_sha_256 {
278 return Err(PgError::Auth(
279 "Server requested SCRAM authentication, but SCRAM is disabled by AuthSettings"
280 .to_string(),
281 ));
282 }
283 let password = password.ok_or_else(|| {
284 PgError::Auth("Password required for SCRAM authentication".to_string())
285 })?;
286
287 let tls_binding = self.tls_server_end_point_channel_binding();
288 let (mechanism, channel_binding_data) = select_scram_mechanism(
289 &mechanisms,
290 tls_binding,
291 auth_settings.channel_binding,
292 )
293 .map_err(PgError::Auth)?;
294
295 let client = if let Some(binding_data) = channel_binding_data {
296 ScramClient::new_with_tls_server_end_point(user, password, binding_data)
297 } else {
298 ScramClient::new(user, password)
299 };
300 let first_message = client.client_first_message();
301
302 self.send(FrontendMessage::SASLInitialResponse {
303 mechanism,
304 data: first_message,
305 })
306 .await?;
307
308 scram_client = Some(client);
309 }
310 BackendMessage::AuthenticationSASLContinue(server_data) => {
311 match startup_auth_flow {
312 Some(StartupAuthFlow::Scram {
313 server_final_seen: false,
314 }) => {}
315 Some(StartupAuthFlow::Scram {
316 server_final_seen: true,
317 }) => {
318 return Err(PgError::Protocol(
319 "Received AuthenticationSASLContinue after AuthenticationSASLFinal"
320 .to_string(),
321 ));
322 }
323 Some(flow) => {
324 return Err(PgError::Protocol(format!(
325 "Received AuthenticationSASLContinue while {} authentication is in progress",
326 flow.label()
327 )));
328 }
329 None => {
330 return Err(PgError::Auth(
331 "Received SASL Continue without SASL init".to_string(),
332 ));
333 }
334 }
335
336 let client = scram_client.as_mut().ok_or_else(|| {
337 PgError::Auth("Received SASL Continue without SASL init".to_string())
338 })?;
339
340 let final_message = client
341 .process_server_first(&server_data)
342 .map_err(|e| PgError::Auth(format!("SCRAM error: {}", e)))?;
343
344 self.send(FrontendMessage::SASLResponse(final_message))
345 .await?;
346 }
347 BackendMessage::AuthenticationSASLFinal(server_signature) => {
348 match startup_auth_flow {
349 Some(StartupAuthFlow::Scram {
350 server_final_seen: false,
351 }) => {
352 startup_auth_flow = Some(StartupAuthFlow::Scram {
353 server_final_seen: true,
354 });
355 }
356 Some(StartupAuthFlow::Scram {
357 server_final_seen: true,
358 }) => {
359 return Err(PgError::Protocol(
360 "Received duplicate AuthenticationSASLFinal".to_string(),
361 ));
362 }
363 Some(flow) => {
364 return Err(PgError::Protocol(format!(
365 "Received AuthenticationSASLFinal while {} authentication is in progress",
366 flow.label()
367 )));
368 }
369 None => {
370 return Err(PgError::Auth(
371 "Received SASL Final without SASL init".to_string(),
372 ));
373 }
374 }
375
376 let client = scram_client.as_ref().ok_or_else(|| {
377 PgError::Auth("Received SASL Final without SASL init".to_string())
378 })?;
379 client
380 .verify_server_final(&server_signature)
381 .map_err(|e| PgError::Auth(format!("Server verification failed: {}", e)))?;
382 }
383 BackendMessage::ParameterStatus { .. } => {
384 if !saw_auth_ok {
385 return Err(PgError::Protocol(
386 "Received ParameterStatus before AuthenticationOk".to_string(),
387 ));
388 }
389 }
390 BackendMessage::BackendKeyData {
391 process_id,
392 secret_key,
393 } => {
394 if !saw_auth_ok {
395 return Err(PgError::Protocol(
396 "Received BackendKeyData before AuthenticationOk".to_string(),
397 ));
398 }
399 self.process_id = process_id;
400 self.secret_key = secret_key;
401 }
402 BackendMessage::ReadyForQuery(TransactionStatus::Idle)
403 | BackendMessage::ReadyForQuery(TransactionStatus::InBlock)
404 | BackendMessage::ReadyForQuery(TransactionStatus::Failed) => {
405 if !saw_auth_ok {
406 return Err(PgError::Protocol(
407 "Startup completed without AuthenticationOk".to_string(),
408 ));
409 }
410 return Ok(());
411 }
412 BackendMessage::ErrorResponse(err) => {
413 return Err(PgError::Connection(err.message));
414 }
415 BackendMessage::NoticeResponse(_) => {}
416 _ => {
417 return Err(PgError::Protocol(
418 "Unexpected backend message during startup".to_string(),
419 ));
420 }
421 }
422 }
423 }
424
425 fn tls_server_end_point_channel_binding(&self) -> Option<Vec<u8>> {
430 let PgStream::Tls(tls) = &self.stream else {
431 return None;
432 };
433
434 let (_, conn) = tls.get_ref();
435 let certs = conn.peer_certificates()?;
436 let leaf_cert = certs.first()?;
437
438 let mut hasher = Sha256::new();
439 hasher.update(leaf_cert.as_ref());
440 Some(hasher.finalize().to_vec())
441 }
442
443 pub async fn close(mut self) -> PgResult<()> {
446 use crate::protocol::PgEncoder;
447
448 let terminate = PgEncoder::encode_terminate();
450 self.write_all_with_timeout(&terminate, "stream write")
451 .await?;
452 self.flush_with_timeout("stream flush").await?;
453
454 Ok(())
455 }
456
457 pub(crate) const MAX_PREPARED_PER_CONN: usize = 128;
463
464 pub(crate) fn evict_prepared_if_full(&mut self) {
470 if self.prepared_statements.len() >= Self::MAX_PREPARED_PER_CONN {
471 if let Some((evicted_hash, evicted_name)) = self.stmt_cache.pop_lru() {
473 self.prepared_statements.remove(&evicted_name);
474 self.column_info_cache.remove(&evicted_hash);
475 self.pending_statement_closes.push(evicted_name);
476 } else {
477 if let Some(key) = self.prepared_statements.keys().next().cloned() {
481 self.prepared_statements.remove(&key);
482 self.pending_statement_closes.push(key);
483 }
484 }
485 }
486 }
487
488 pub(crate) fn clear_prepared_statement_state(&mut self) {
493 self.stmt_cache.clear();
494 self.prepared_statements.clear();
495 self.column_info_cache.clear();
496 self.pending_statement_closes.clear();
497 }
498}