qail_pg/driver/connection/
helpers.rs1#[cfg(all(target_os = "linux", feature = "io_uring"))]
4use super::types::CONNECT_BACKEND_IO_URING;
5use super::types::{CONNECT_BACKEND_TOKIO, PgConnection};
6use crate::driver::stream::PgStream;
7use crate::driver::{
8 EnterpriseAuthMechanism, GssTokenProvider, GssTokenProviderEx, GssTokenRequest, PgError,
9 PgResult, ScramChannelBindingMode,
10};
11
12pub(super) fn generate_gss_token(
13 session_id: u64,
14 mechanism: EnterpriseAuthMechanism,
15 server_token: Option<&[u8]>,
16 legacy_provider: Option<GssTokenProvider>,
17 stateful_provider: Option<&GssTokenProviderEx>,
18) -> Result<Vec<u8>, String> {
19 if let Some(provider) = stateful_provider {
20 return provider(GssTokenRequest {
21 session_id,
22 mechanism,
23 server_token,
24 });
25 }
26
27 if let Some(provider) = legacy_provider {
28 return provider(mechanism, server_token);
29 }
30
31 Err("No GSS token provider configured".to_string())
32}
33
34pub(super) fn plain_connect_attempt_backend() -> &'static str {
35 #[cfg(all(target_os = "linux", feature = "io_uring"))]
36 {
37 if should_try_uring_plain() {
38 return CONNECT_BACKEND_IO_URING;
39 }
40 }
41 CONNECT_BACKEND_TOKIO
42}
43
44pub(super) fn connect_backend_for_stream(stream: &PgStream) -> &'static str {
45 match stream {
46 PgStream::Tcp(_) => CONNECT_BACKEND_TOKIO,
47 #[cfg(all(target_os = "linux", feature = "io_uring"))]
48 PgStream::Uring(_) => CONNECT_BACKEND_IO_URING,
49 PgStream::Tls(_) => CONNECT_BACKEND_TOKIO,
50 #[cfg(unix)]
51 PgStream::Unix(_) => CONNECT_BACKEND_TOKIO,
52 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
53 PgStream::GssEnc(_) => CONNECT_BACKEND_TOKIO,
54 }
55}
56
57pub(super) fn connect_error_kind(error: &PgError) -> &'static str {
58 match error {
59 PgError::Connection(_) => "connection",
60 PgError::Protocol(_) => "protocol",
61 PgError::Auth(_) => "auth",
62 PgError::Query(_) | PgError::QueryServer(_) => "query",
63 PgError::NoRows => "no_rows",
64 PgError::Io(_) => "io",
65 PgError::Encode(_) => "encode",
66 PgError::Timeout(_) => "timeout",
67 PgError::PoolExhausted { .. } => "pool_exhausted",
68 PgError::PoolClosed => "pool_closed",
69 }
70}
71
72pub(super) fn record_connect_attempt(transport: &'static str, backend: &'static str) {
73 metrics::counter!(
74 "qail_pg_connect_attempt_total",
75 "transport" => transport,
76 "backend" => backend
77 )
78 .increment(1);
79}
80
81pub(super) fn record_connect_result(
82 transport: &'static str,
83 backend: &'static str,
84 result: &PgResult<PgConnection>,
85 elapsed: std::time::Duration,
86) {
87 let outcome = if result.is_ok() { "success" } else { "error" };
88 metrics::histogram!(
89 "qail_pg_connect_duration_seconds",
90 "transport" => transport,
91 "backend" => backend,
92 "outcome" => outcome
93 )
94 .record(elapsed.as_secs_f64());
95
96 if let Err(error) = result {
97 metrics::counter!(
98 "qail_pg_connect_failure_total",
99 "transport" => transport,
100 "backend" => backend,
101 "error_kind" => connect_error_kind(error)
102 )
103 .increment(1);
104 } else {
105 metrics::counter!(
106 "qail_pg_connect_success_total",
107 "transport" => transport,
108 "backend" => backend
109 )
110 .increment(1);
111 }
112}
113
114pub(super) fn select_scram_mechanism(
115 mechanisms: &[String],
116 tls_server_end_point_binding: Option<Vec<u8>>,
117 channel_binding_mode: ScramChannelBindingMode,
118) -> Result<(String, Option<Vec<u8>>), String> {
119 let has_scram = mechanisms.iter().any(|m| m == "SCRAM-SHA-256");
120 let has_scram_plus = mechanisms.iter().any(|m| m == "SCRAM-SHA-256-PLUS");
121
122 match channel_binding_mode {
123 ScramChannelBindingMode::Disable => {
124 if has_scram {
125 return Ok(("SCRAM-SHA-256".to_string(), None));
126 }
127 Err(format!(
128 "channel_binding=disable, but server does not advertise SCRAM-SHA-256. Available: {:?}",
129 mechanisms
130 ))
131 }
132 ScramChannelBindingMode::Prefer => {
133 if has_scram_plus {
134 if let Some(binding) = tls_server_end_point_binding {
135 return Ok(("SCRAM-SHA-256-PLUS".to_string(), Some(binding)));
136 }
137
138 if has_scram {
139 return Ok(("SCRAM-SHA-256".to_string(), None));
140 }
141
142 return Err(
143 "Server requires SCRAM-SHA-256-PLUS but TLS channel binding is unavailable"
144 .to_string(),
145 );
146 }
147
148 if has_scram {
149 return Ok(("SCRAM-SHA-256".to_string(), None));
150 }
151
152 Err(format!(
153 "Server doesn't support SCRAM-SHA-256. Available: {:?}",
154 mechanisms
155 ))
156 }
157 ScramChannelBindingMode::Require => {
158 if !has_scram_plus {
159 return Err(
160 "channel_binding=require, but server does not advertise SCRAM-SHA-256-PLUS"
161 .to_string(),
162 );
163 }
164 let binding = tls_server_end_point_binding.ok_or_else(|| {
165 "channel_binding=require, but TLS channel binding data is unavailable".to_string()
166 })?;
167 Ok(("SCRAM-SHA-256-PLUS".to_string(), Some(binding)))
168 }
169 }
170}
171
172pub(super) fn md5_password_message(user: &str, password: &str, salt: [u8; 4]) -> String {
174 use md5::{Digest, Md5};
175
176 let mut inner = Md5::new();
177 inner.update(password.as_bytes());
178 inner.update(user.as_bytes());
179 let inner_hex = format!("{:x}", inner.finalize());
180
181 let mut outer = Md5::new();
182 outer.update(inner_hex.as_bytes());
183 outer.update(salt);
184 format!("md5{:x}", outer.finalize())
185}
186
187impl Drop for PgConnection {
190 fn drop(&mut self) {
191 let terminate: [u8; 5] = [b'X', 0, 0, 0, 4];
194
195 match &mut self.stream {
196 PgStream::Tcp(tcp) => {
197 let _ = tcp.try_write(&terminate);
199 }
200 #[cfg(all(target_os = "linux", feature = "io_uring"))]
201 PgStream::Uring(stream) => {
202 let _ = stream.abort_inflight();
206 }
207 PgStream::Tls(_) => {
208 }
212 #[cfg(unix)]
213 PgStream::Unix(unix) => {
214 let _ = unix.try_write(&terminate);
215 }
216 #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
217 PgStream::GssEnc(_) => {
218 }
220 }
221 }
222}
223
/// Extracts the affected-row count from a command tag such as `"INSERT 0 5"`.
///
/// The count is the last whitespace-separated token of the tag:
/// - `"INSERT 0 5"` → 5
/// - `"UPDATE 3"` → 3
/// - `"DELETE 0"` → 0
///
/// A tag whose last token is not a valid `u64` (e.g. `"BEGIN"`,
/// `"CREATE TABLE"`), or an empty/blank tag, yields 0.
pub(crate) fn parse_affected_rows(tag: &str) -> u64 {
    // `next_back` reads the final token directly; `Iterator::last` would walk
    // every token first (clippy::double_ended_iterator_last).
    tag.split_whitespace()
        .next_back()
        .and_then(|count| count.parse().ok())
        .unwrap_or(0)
}
230
/// Returns whether plain connections should attempt the io_uring transport.
///
/// This delegates entirely to the driver's `io_backend` module. It exists only
/// on Linux builds with the `io_uring` feature.
#[cfg(all(target_os = "linux", feature = "io_uring"))]
pub(super) fn should_try_uring_plain() -> bool {
    crate::driver::io_backend::should_use_uring_plain_transport()
}