Skip to main content

qail_pg/driver/connection/
helpers.rs

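//! Free helper functions: GSS token generation, connect backend labels and metrics,
//! MD5 password hashing, SCRAM mechanism selection, command-tag row counts, and the
//! `PgConnection` Drop impl.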
2
3#[cfg(all(target_os = "linux", feature = "io_uring"))]
4use super::types::CONNECT_BACKEND_IO_URING;
5use super::types::{CONNECT_BACKEND_TOKIO, PgConnection};
6use crate::driver::stream::PgStream;
7use crate::driver::{
8    EnterpriseAuthMechanism, GssTokenProvider, GssTokenProviderEx, GssTokenRequest, PgError,
9    PgResult, ScramChannelBindingMode,
10};
11
12pub(super) fn generate_gss_token(
13    session_id: u64,
14    mechanism: EnterpriseAuthMechanism,
15    server_token: Option<&[u8]>,
16    legacy_provider: Option<GssTokenProvider>,
17    stateful_provider: Option<&GssTokenProviderEx>,
18) -> Result<Vec<u8>, String> {
19    if let Some(provider) = stateful_provider {
20        return provider(GssTokenRequest {
21            session_id,
22            mechanism,
23            server_token,
24        });
25    }
26
27    if let Some(provider) = legacy_provider {
28        return provider(mechanism, server_token);
29    }
30
31    Err("No GSS token provider configured".to_string())
32}
33
34pub(super) fn plain_connect_attempt_backend() -> &'static str {
35    #[cfg(all(target_os = "linux", feature = "io_uring"))]
36    {
37        if should_try_uring_plain() {
38            return CONNECT_BACKEND_IO_URING;
39        }
40    }
41    CONNECT_BACKEND_TOKIO
42}
43
44pub(super) fn connect_backend_for_stream(stream: &PgStream) -> &'static str {
45    match stream {
46        PgStream::Tcp(_) => CONNECT_BACKEND_TOKIO,
47        #[cfg(all(target_os = "linux", feature = "io_uring"))]
48        PgStream::Uring(_) => CONNECT_BACKEND_IO_URING,
49        PgStream::Tls(_) => CONNECT_BACKEND_TOKIO,
50        #[cfg(unix)]
51        PgStream::Unix(_) => CONNECT_BACKEND_TOKIO,
52        #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
53        PgStream::GssEnc(_) => CONNECT_BACKEND_TOKIO,
54    }
55}
56
57pub(super) fn connect_error_kind(error: &PgError) -> &'static str {
58    match error {
59        PgError::Connection(_) => "connection",
60        PgError::Protocol(_) => "protocol",
61        PgError::Auth(_) => "auth",
62        PgError::Query(_) | PgError::QueryServer(_) => "query",
63        PgError::NoRows => "no_rows",
64        PgError::Io(_) => "io",
65        PgError::Encode(_) => "encode",
66        PgError::Timeout(_) => "timeout",
67        PgError::PoolExhausted { .. } => "pool_exhausted",
68        PgError::PoolClosed => "pool_closed",
69    }
70}
71
72pub(super) fn record_connect_attempt(transport: &'static str, backend: &'static str) {
73    metrics::counter!(
74        "qail_pg_connect_attempt_total",
75        "transport" => transport,
76        "backend" => backend
77    )
78    .increment(1);
79}
80
81pub(super) fn record_connect_result(
82    transport: &'static str,
83    backend: &'static str,
84    result: &PgResult<PgConnection>,
85    elapsed: std::time::Duration,
86) {
87    let outcome = if result.is_ok() { "success" } else { "error" };
88    metrics::histogram!(
89        "qail_pg_connect_duration_seconds",
90        "transport" => transport,
91        "backend" => backend,
92        "outcome" => outcome
93    )
94    .record(elapsed.as_secs_f64());
95
96    if let Err(error) = result {
97        metrics::counter!(
98            "qail_pg_connect_failure_total",
99            "transport" => transport,
100            "backend" => backend,
101            "error_kind" => connect_error_kind(error)
102        )
103        .increment(1);
104    } else {
105        metrics::counter!(
106            "qail_pg_connect_success_total",
107            "transport" => transport,
108            "backend" => backend
109        )
110        .increment(1);
111    }
112}
113
114pub(super) fn select_scram_mechanism(
115    mechanisms: &[String],
116    tls_server_end_point_binding: Option<Vec<u8>>,
117    channel_binding_mode: ScramChannelBindingMode,
118) -> Result<(String, Option<Vec<u8>>), String> {
119    let has_scram = mechanisms.iter().any(|m| m == "SCRAM-SHA-256");
120    let has_scram_plus = mechanisms.iter().any(|m| m == "SCRAM-SHA-256-PLUS");
121
122    match channel_binding_mode {
123        ScramChannelBindingMode::Disable => {
124            if has_scram {
125                return Ok(("SCRAM-SHA-256".to_string(), None));
126            }
127            Err(format!(
128                "channel_binding=disable, but server does not advertise SCRAM-SHA-256. Available: {:?}",
129                mechanisms
130            ))
131        }
132        ScramChannelBindingMode::Prefer => {
133            if has_scram_plus {
134                if let Some(binding) = tls_server_end_point_binding {
135                    return Ok(("SCRAM-SHA-256-PLUS".to_string(), Some(binding)));
136                }
137
138                if has_scram {
139                    return Ok(("SCRAM-SHA-256".to_string(), None));
140                }
141
142                return Err(
143                    "Server requires SCRAM-SHA-256-PLUS but TLS channel binding is unavailable"
144                        .to_string(),
145                );
146            }
147
148            if has_scram {
149                return Ok(("SCRAM-SHA-256".to_string(), None));
150            }
151
152            Err(format!(
153                "Server doesn't support SCRAM-SHA-256. Available: {:?}",
154                mechanisms
155            ))
156        }
157        ScramChannelBindingMode::Require => {
158            if !has_scram_plus {
159                return Err(
160                    "channel_binding=require, but server does not advertise SCRAM-SHA-256-PLUS"
161                        .to_string(),
162                );
163            }
164            let binding = tls_server_end_point_binding.ok_or_else(|| {
165                "channel_binding=require, but TLS channel binding data is unavailable".to_string()
166            })?;
167            Ok(("SCRAM-SHA-256-PLUS".to_string(), Some(binding)))
168        }
169    }
170}
171
172/// PostgreSQL MD5 password response: `md5` + md5(hex(md5(password + user)) + 4-byte salt).
173pub(super) fn md5_password_message(user: &str, password: &str, salt: [u8; 4]) -> String {
174    use md5::{Digest, Md5};
175
176    let mut inner = Md5::new();
177    inner.update(password.as_bytes());
178    inner.update(user.as_bytes());
179    let inner_hex = format!("{:x}", inner.finalize());
180
181    let mut outer = Md5::new();
182    outer.update(inner_hex.as_bytes());
183    outer.update(salt);
184    format!("md5{:x}", outer.finalize())
185}
186
187/// Drop implementation sends Terminate packet if possible.
188/// This ensures proper cleanup even without explicit close() call.
189impl Drop for PgConnection {
190    fn drop(&mut self) {
191        // Try to send Terminate packet synchronously using try_write
192        // This is best-effort - if it fails, TCP RST will handle cleanup
193        let terminate: [u8; 5] = [b'X', 0, 0, 0, 4];
194
195        match &mut self.stream {
196            PgStream::Tcp(tcp) => {
197                // try_write is non-blocking
198                let _ = tcp.try_write(&terminate);
199            }
200            #[cfg(all(target_os = "linux", feature = "io_uring"))]
201            PgStream::Uring(stream) => {
202                // io_uring transport uses blocking worker operations;
203                // terminate packet in Drop is not viable, but force socket
204                // shutdown so timed-out worker ops unblock promptly.
205                let _ = stream.abort_inflight();
206            }
207            PgStream::Tls(_) => {
208                // TLS requires async write which we can't do in Drop.
209                // The TCP connection close will still notify the server.
210                // For graceful TLS shutdown, use connection.close() explicitly.
211            }
212            #[cfg(unix)]
213            PgStream::Unix(unix) => {
214                let _ = unix.try_write(&terminate);
215            }
216            #[cfg(all(feature = "enterprise-gssapi", target_os = "linux"))]
217            PgStream::GssEnc(_) => {
218                // GSSENC requires async wrap+write; skip in Drop.
219            }
220        }
221    }
222}
223
/// Extract the row count from a command-completion tag such as `"INSERT 0 5"`,
/// `"UPDATE 3"` or `"SELECT 10"`.
///
/// The count is taken from the last whitespace-separated word. Returns `0`
/// when the tag is empty or that word is not a non-negative integer (e.g.
/// `"CREATE TABLE"`).
pub(crate) fn parse_affected_rows(tag: &str) -> u64 {
    // `SplitWhitespace` is double-ended, so `next_back` jumps straight to the last
    // word instead of scanning every word the way `last()` does.
    tag.split_whitespace()
        .next_back()
        .and_then(|word| word.parse().ok())
        .unwrap_or(0)
}
230
/// Whether plain connect attempts should try the io_uring transport.
///
/// Delegates to `io_backend::should_use_uring_plain_transport` two modules up.
#[cfg(all(target_os = "linux", feature = "io_uring"))]
pub(super) fn should_try_uring_plain() -> bool {
    use super::super::io_backend::should_use_uring_plain_transport as uring_plain_enabled;

    uring_plain_enabled()
}