Skip to main content

heliosdb_proxy/backend/
client.rs

//! Backend PostgreSQL client.
//!
//! Issues simple-query SQL against a backend node. Every message is framed
//! through the existing `crate::protocol` types, so the crate keeps a single
//! wire-protocol implementation.
//!
//! This code path is distinct from `ProxyServer::route_and_forward`, which
//! *forwards* client frames and remains zero-copy. This path is for
//! **internal TR management** queries (health checks, `pg_is_in_recovery()`,
//! `pg_promote()`, WAL-position probes, failover replay, session-state
//! restoration).
12
13use super::auth::{md5_password_response, Scram};
14use super::error::{BackendError, BackendResult};
15use super::stream::Stream;
16use super::tls::{negotiate, TlsMode};
17use super::types::{encode_literal, ParamValue, TextValue};
18use crate::protocol::{Message, MessageType, ProtocolCodec};
19use bytes::{Buf, BufMut, BytesMut};
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::TcpStream;
24
25/// Backend connection parameters.
26#[derive(Debug, Clone)]
27pub struct BackendConfig {
28    /// Hostname (also used for TLS SNI).
29    pub host: String,
30    /// TCP port.
31    pub port: u16,
32    /// PostgreSQL user.
33    pub user: String,
34    /// PostgreSQL password. If `None`, only `AuthenticationOk` (no
35    /// password required) is accepted.
36    pub password: Option<String>,
37    /// Target database. If `None`, connects to the user's default.
38    pub database: Option<String>,
39    /// Application name reported via the `application_name` startup
40    /// parameter. Defaults to `heliosdb-proxy` when `None`.
41    pub application_name: Option<String>,
42    /// TLS policy.
43    pub tls_mode: TlsMode,
44    /// Connect-timeout ceiling (covers DNS + TCP + TLS + startup).
45    pub connect_timeout: Duration,
46    /// Per-query ceiling (round-trip from Query send to ReadyForQuery).
47    pub query_timeout: Duration,
48    /// Shared rustls `ClientConfig` — build once via
49    /// `super::tls::default_client_config` and reuse across connections.
50    pub tls_config: Arc<rustls::ClientConfig>,
51}
52
53impl BackendConfig {
54    pub fn address(&self) -> String {
55        format!("{}:{}", self.host, self.port)
56    }
57}
58
59/// An established, authenticated client connection to a backend.
60pub struct BackendClient {
61    stream: Stream,
62    /// Parameter values the server sent during startup (client_encoding,
63    /// server_version, TimeZone, …). Useful for diagnostics.
64    pub server_parameters: std::collections::HashMap<String, String>,
65    /// BackendKeyData cached for potential cancel requests.
66    pub backend_pid: Option<u32>,
67    pub backend_secret: Option<u32>,
68}
69
70impl BackendClient {
71    /// Connect, TLS-negotiate, authenticate, and drain server-initialisation
72    /// frames through `ReadyForQuery`. On success the client is idle and
73    /// ready to run SQL.
74    pub async fn connect(cfg: &BackendConfig) -> BackendResult<Self> {
75        tokio::time::timeout(cfg.connect_timeout, Self::connect_inner(cfg))
76            .await
77            .map_err(|_| BackendError::Io(std::io::Error::new(
78                std::io::ErrorKind::TimedOut,
79                format!("connect to {} exceeded {:?}", cfg.address(), cfg.connect_timeout),
80            )))?
81    }
82
83    async fn connect_inner(cfg: &BackendConfig) -> BackendResult<Self> {
84        let tcp = TcpStream::connect(cfg.address()).await?;
85        let mut stream =
86            negotiate(tcp, cfg.tls_mode, cfg.tls_config.clone(), &cfg.host).await?;
87
88        // Send StartupMessage (protocol 3.0).
89        let startup = build_startup(cfg);
90        stream.write_all(&startup).await?;
91
92        let mut server_parameters = std::collections::HashMap::new();
93        let mut backend_pid = None;
94        let mut backend_secret = None;
95        let mut buffer = BytesMut::with_capacity(4096);
96        let codec = ProtocolCodec::new();
97        let mut scram_state: Option<Scram> = None;
98
99        loop {
100            let msg = read_one(&mut stream, &mut buffer, &codec).await?;
101            match msg.msg_type {
102                MessageType::AuthRequest => {
103                    handle_auth(
104                        &mut stream,
105                        &msg,
106                        cfg,
107                        &mut scram_state,
108                    )
109                    .await?;
110                }
111                MessageType::ParameterStatus => {
112                    if let Some((k, v)) = parse_parameter_status(&msg.payload) {
113                        server_parameters.insert(k, v);
114                    }
115                }
116                MessageType::BackendKeyData => {
117                    if msg.payload.len() >= 8 {
118                        backend_pid = Some(u32::from_be_bytes(
119                            msg.payload[0..4].try_into().unwrap(),
120                        ));
121                        backend_secret = Some(u32::from_be_bytes(
122                            msg.payload[4..8].try_into().unwrap(),
123                        ));
124                    }
125                }
126                MessageType::ReadyForQuery => {
127                    return Ok(Self {
128                        stream,
129                        server_parameters,
130                        backend_pid,
131                        backend_secret,
132                    });
133                }
134                MessageType::ErrorResponse => {
135                    return Err(BackendError::BackendError(error_message(&msg.payload)));
136                }
137                MessageType::NoticeResponse => {
138                    // Ignore warnings during startup.
139                }
140                other => {
141                    return Err(BackendError::Protocol(format!(
142                        "unexpected message during startup: {:?}",
143                        other
144                    )));
145                }
146            }
147        }
148    }
149
150    /// Run a simple-query (SQL text) and collect every resulting row
151    /// into a `QueryResult`. For statements that don't return rows
152    /// (DDL, `SET`, etc.) `rows` will be empty and `command_tag`
153    /// carries the completion string.
154    pub async fn simple_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
155        self.run_query(sql).await
156    }
157
158    /// Like `simple_query` but substitutes `$1`, `$2`, … with
159    /// text-format literals before sending. We stick to simple-query
160    /// rather than the extended protocol to keep the surface narrow.
161    pub async fn query_with_params(
162        &mut self,
163        sql: &str,
164        params: &[ParamValue],
165    ) -> BackendResult<QueryResult> {
166        let substituted = interpolate_params(sql, params)?;
167        self.run_query(&substituted).await
168    }
169
170    /// Shorthand for a scalar lookup: runs `sql`, expects 1 column, 1 row.
171    pub async fn query_scalar(&mut self, sql: &str) -> BackendResult<TextValue> {
172        let res = self.simple_query(sql).await?;
173        if res.rows.len() != 1 {
174            return Err(BackendError::Protocol(format!(
175                "expected 1 row, got {}",
176                res.rows.len()
177            )));
178        }
179        if res.columns.len() != 1 {
180            return Err(BackendError::Protocol(format!(
181                "expected 1 column, got {}",
182                res.columns.len()
183            )));
184        }
185        Ok(res.rows.into_iter().next().unwrap().into_iter().next().unwrap())
186    }
187
188    /// Run a statement with no result set (DDL, SET, DO). Returns the
189    /// command tag (e.g. `"SET"`, `"CREATE TABLE"`).
190    pub async fn execute(&mut self, sql: &str) -> BackendResult<String> {
191        let res = self.simple_query(sql).await?;
192        Ok(res.command_tag)
193    }
194
195    /// Bulk-load rows via `COPY <table> [(cols)] FROM STDIN` (text format).
196    /// `data` is the pre-encoded COPY text payload (rows already tab-delimited,
197    /// escaped, and newline-terminated; no trailing `\.`). Returns the row count
198    /// from the `COPY n` command tag.
199    ///
200    /// On ANY failure the connection is drained back to `ReadyForQuery`, so a
201    /// caller can fall back to per-row INSERTs cleanly — `COPY` is atomic, so a
202    /// failed load leaves zero rows behind (no double-insert risk).
203    pub async fn copy_in(&mut self, copy_sql: &str, data: &[u8]) -> BackendResult<u64> {
204        // Bulk loads can run long; give COPY a generous ceiling vs the 30s
205        // management-query default.
206        let t = Duration::from_secs(600);
207        tokio::time::timeout(t, Self::copy_in_inner(&mut self.stream, copy_sql, data))
208            .await
209            .map_err(|_| {
210                BackendError::Io(std::io::Error::new(
211                    std::io::ErrorKind::TimedOut,
212                    format!("COPY exceeded {:?}", t),
213                ))
214            })?
215    }
216
217    async fn copy_in_inner(stream: &mut Stream, copy_sql: &str, data: &[u8]) -> BackendResult<u64> {
218        // 1. Send the COPY ... FROM STDIN command as a simple Query.
219        let mut payload = BytesMut::with_capacity(copy_sql.len() + 1);
220        payload.extend_from_slice(copy_sql.as_bytes());
221        payload.put_u8(0);
222        stream
223            .write_all(&Message::new(MessageType::Query, payload).encode())
224            .await?;
225
226        let mut buffer = BytesMut::with_capacity(8192);
227        let codec = ProtocolCodec::new();
228
229        // 2. Expect CopyInResponse. Its tag 'G' isn't in the shared decoder, so
230        //    it surfaces as Unknown(b'G'). An ErrorResponse here means the
231        //    backend rejected COPY (e.g. unsupported) — drain to RFQ and surface
232        //    a recoverable error so the caller falls back to INSERTs.
233        loop {
234            let msg = read_one(stream, &mut buffer, &codec).await?;
235            match msg.msg_type {
236                MessageType::Unknown(b'G') => break,
237                MessageType::ErrorResponse => {
238                    let e = error_message(&msg.payload);
239                    Self::drain_to_ready(stream, &mut buffer, &codec).await?;
240                    return Err(BackendError::BackendError(e));
241                }
242                MessageType::ReadyForQuery => {
243                    return Err(BackendError::Protocol(
244                        "COPY: ReadyForQuery before CopyInResponse".into(),
245                    ));
246                }
247                _ => {} // NoticeResponse / ParameterStatus / etc — keep waiting
248            }
249        }
250
251        // 3. Stream the payload as CopyData frames, then CopyDone.
252        const CHUNK: usize = 64 * 1024;
253        let mut off = 0;
254        while off < data.len() {
255            let end = (off + CHUNK).min(data.len());
256            let mut p = BytesMut::with_capacity(end - off);
257            p.extend_from_slice(&data[off..end]);
258            stream
259                .write_all(&Message::new(MessageType::CopyData, p).encode())
260                .await?;
261            off = end;
262        }
263        stream
264            .write_all(&Message::new(MessageType::CopyDone, BytesMut::new()).encode())
265            .await?;
266
267        // 4. Read to ReadyForQuery: CommandComplete "COPY n" or ErrorResponse.
268        let mut tag = String::new();
269        let mut last_error = None;
270        loop {
271            let msg = read_one(stream, &mut buffer, &codec).await?;
272            match msg.msg_type {
273                MessageType::CommandComplete | MessageType::Close => {
274                    tag = parse_cstring(&msg.payload);
275                }
276                MessageType::ErrorResponse => {
277                    last_error = Some(error_message(&msg.payload));
278                }
279                MessageType::ReadyForQuery => {
280                    if let Some(e) = last_error {
281                        return Err(BackendError::BackendError(e));
282                    }
283                    // "COPY n" -> n
284                    let n = tag
285                        .rsplit(' ')
286                        .next()
287                        .and_then(|s| s.parse::<u64>().ok())
288                        .unwrap_or(0);
289                    return Ok(n);
290                }
291                _ => {}
292            }
293        }
294    }
295
296    async fn drain_to_ready(
297        stream: &mut Stream,
298        buffer: &mut BytesMut,
299        codec: &ProtocolCodec,
300    ) -> BackendResult<()> {
301        loop {
302            if read_one(stream, buffer, codec).await?.msg_type == MessageType::ReadyForQuery {
303                return Ok(());
304            }
305        }
306    }
307
308    async fn run_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
309        let t = self.stream_query_timeout();
310        tokio::time::timeout(t, Self::run_query_inner(&mut self.stream, sql))
311            .await
312            .map_err(|_| BackendError::Io(std::io::Error::new(
313                std::io::ErrorKind::TimedOut,
314                format!("query exceeded {:?}: {}", t, truncate(sql, 64)),
315            )))?
316    }
317
318    fn stream_query_timeout(&self) -> Duration {
319        // 30 seconds is a sane default for management queries; callers
320        // that need longer can wrap their own timeout around this call.
321        Duration::from_secs(30)
322    }
323
324    async fn run_query_inner(stream: &mut Stream, sql: &str) -> BackendResult<QueryResult> {
325        // Build and send a Query message (tag 'Q', payload = SQL + \0).
326        let mut payload = BytesMut::with_capacity(sql.len() + 1);
327        payload.extend_from_slice(sql.as_bytes());
328        payload.put_u8(0);
329        let frame = Message::new(MessageType::Query, payload).encode();
330        stream.write_all(&frame).await?;
331
332        let mut buffer = BytesMut::with_capacity(8192);
333        let codec = ProtocolCodec::new();
334        let mut columns: Vec<ColumnMeta> = Vec::new();
335        let mut rows: Vec<Vec<TextValue>> = Vec::new();
336        let mut command_tag = String::new();
337        let mut last_error: Option<String> = None;
338
339        loop {
340            let msg = read_one(stream, &mut buffer, &codec).await?;
341            match msg.msg_type {
342                // Both 'T' (RowDescription) and 'C' (CommandComplete)
343                // may appear; from_tag conflates T with… nothing. It
344                // DOES conflate D with Describe, but on a server→client
345                // frame the D tag here is always DataRow in practice,
346                // because we only arrive at run_query_inner after the
347                // startup handshake and the server never sends Describe.
348                MessageType::RowDescription => {
349                    columns = parse_row_description(&msg.payload);
350                }
351                MessageType::DataRow => {
352                    let row = parse_data_row(&msg.payload, columns.len())?;
353                    rows.push(row);
354                }
355                // PG tag 'C' = CommandComplete (server → client). The
356                // shared MessageType enum also has Close (client → server)
357                // under the same tag; again, the direction fixes the
358                // ambiguity at runtime.
359                MessageType::CommandComplete | MessageType::Close => {
360                    command_tag = parse_cstring(&msg.payload);
361                }
362                MessageType::EmptyQueryResponse => {
363                    command_tag = String::new();
364                }
365                MessageType::ErrorResponse => {
366                    last_error = Some(error_message(&msg.payload));
367                }
368                MessageType::NoticeResponse => {
369                    tracing::debug!(notice = %error_message(&msg.payload), "backend notice");
370                }
371                MessageType::ReadyForQuery => {
372                    if let Some(e) = last_error {
373                        return Err(BackendError::BackendError(e));
374                    }
375                    return Ok(QueryResult {
376                        columns,
377                        rows,
378                        command_tag,
379                    });
380                }
381                MessageType::ParameterStatus => {
382                    // Server may push parameter changes (e.g. after SET).
383                    // Ignore here; callers that care can call
384                    // simple_query("SHOW <param>") afterwards.
385                }
386                _other => {
387                    // Unknown message kind — keep draining until
388                    // ReadyForQuery. A well-behaved PG server won't
389                    // send anything strange outside of the above set.
390                }
391            }
392        }
393    }
394
395    /// Close the connection gracefully (send Terminate, close socket).
396    pub async fn close(mut self) {
397        let term = Message::new(MessageType::Terminate, BytesMut::new()).encode();
398        let _ = self.stream.write_all(&term).await;
399        let _ = self.stream.shutdown().await;
400    }
401
402    /// Report whether the underlying connection is over TLS.
403    pub fn is_tls(&self) -> bool {
404        self.stream.is_tls()
405    }
406}
407
408// ---------------------------------------------------------------------
409// Result shape
410// ---------------------------------------------------------------------
411
412#[derive(Debug, Clone)]
413pub struct ColumnMeta {
414    pub name: String,
415    pub type_oid: u32,
416}
417
418#[derive(Debug, Clone)]
419pub struct QueryResult {
420    pub columns: Vec<ColumnMeta>,
421    pub rows: Vec<Vec<TextValue>>,
422    pub command_tag: String,
423}
424
425impl QueryResult {
426    /// Return the numeric suffix of a CommandComplete tag, if any.
427    /// For example, `"INSERT 0 5"` → `Some(5)`.
428    pub fn rows_affected(&self) -> Option<u64> {
429        self.command_tag
430            .split_whitespace()
431            .last()
432            .and_then(|s| s.parse().ok())
433    }
434}
435
436// ---------------------------------------------------------------------
437// Helpers
438// ---------------------------------------------------------------------
439
440fn build_startup(cfg: &BackendConfig) -> Vec<u8> {
441    let mut payload = BytesMut::with_capacity(128);
442    // Protocol version 3.0.
443    payload.put_u32(196608);
444    put_cstring(&mut payload, "user");
445    put_cstring(&mut payload, &cfg.user);
446    if let Some(db) = &cfg.database {
447        put_cstring(&mut payload, "database");
448        put_cstring(&mut payload, db);
449    }
450    put_cstring(&mut payload, "application_name");
451    put_cstring(
452        &mut payload,
453        cfg.application_name
454            .as_deref()
455            .unwrap_or("heliosdb-proxy"),
456    );
457    put_cstring(&mut payload, "client_encoding");
458    put_cstring(&mut payload, "UTF8");
459    payload.put_u8(0); // terminator
460
461    let mut framed = BytesMut::with_capacity(payload.len() + 4);
462    framed.put_u32((payload.len() + 4) as u32);
463    framed.extend_from_slice(&payload);
464    framed.to_vec()
465}
466
467fn put_cstring(buf: &mut BytesMut, s: &str) {
468    buf.extend_from_slice(s.as_bytes());
469    buf.put_u8(0);
470}
471
/// Decodes the bytes before the first NUL (or the whole slice if there is
/// none), replacing invalid UTF-8 with U+FFFD.
fn parse_cstring(payload: &[u8]) -> String {
    let text = payload.split(|&b| b == 0).next().unwrap_or(payload);
    String::from_utf8_lossy(text).into_owned()
}
476
/// Splits a `ParameterStatus` payload (`name\0value\0`) into its two strings.
///
/// Returns `None` when the name has no NUL terminator; a missing terminator
/// after the value is tolerated. Invalid UTF-8 is replaced with U+FFFD.
fn parse_parameter_status(payload: &[u8]) -> Option<(String, String)> {
    let key_end = payload.iter().position(|&b| b == 0)?;
    let (key, after_key) = payload.split_at(key_end);
    let value_region = &after_key[1..]; // skip the key's NUL
    let value = value_region
        .split(|&b| b == 0)
        .next()
        .unwrap_or(value_region);
    let decode = |raw: &[u8]| String::from_utf8_lossy(raw).into_owned();
    Some((decode(key), decode(value)))
}
485
486fn parse_row_description(payload: &[u8]) -> Vec<ColumnMeta> {
487    let mut p = BytesMut::from(payload);
488    if p.remaining() < 2 {
489        return Vec::new();
490    }
491    let n = p.get_u16() as usize;
492    let mut cols = Vec::with_capacity(n);
493    for _ in 0..n {
494        // cstring name
495        let end = match p.as_ref().iter().position(|&b| b == 0) {
496            Some(i) => i,
497            None => break,
498        };
499        let name = String::from_utf8_lossy(&p.as_ref()[..end]).into_owned();
500        p.advance(end + 1);
501        if p.remaining() < 18 {
502            break;
503        }
504        let _table_oid = p.get_u32();
505        let _column_number = p.get_u16();
506        let type_oid = p.get_u32();
507        let _type_len = p.get_i16();
508        let _type_mod = p.get_i32();
509        let _format_code = p.get_u16();
510        cols.push(ColumnMeta { name, type_oid });
511    }
512    cols
513}
514
515fn parse_data_row(payload: &[u8], column_count: usize) -> BackendResult<Vec<TextValue>> {
516    let mut p = BytesMut::from(payload);
517    if p.remaining() < 2 {
518        return Err(BackendError::Protocol("truncated DataRow".into()));
519    }
520    let n = p.get_u16() as usize;
521    let mut out = Vec::with_capacity(n);
522    for _ in 0..n {
523        if p.remaining() < 4 {
524            return Err(BackendError::Protocol("truncated DataRow field".into()));
525        }
526        let len = p.get_i32();
527        if len == -1 {
528            out.push(TextValue::Null);
529        } else {
530            let len = len as usize;
531            if p.remaining() < len {
532                return Err(BackendError::Protocol(
533                    "truncated DataRow value".into(),
534                ));
535            }
536            let bytes = p.split_to(len);
537            out.push(TextValue::Text(
538                String::from_utf8_lossy(&bytes).into_owned(),
539            ));
540        }
541    }
542    let _ = column_count;
543    Ok(out)
544}
545
/// Extracts the human-readable `M` field from an `ErrorResponse` or
/// `NoticeResponse` payload.
///
/// The payload is a sequence of `{code: u8, value: cstring}` fields ended by
/// a zero code. If several `M` fields appear the last one wins; if none
/// appears, `"<no message>"` is returned.
fn error_message(payload: &[u8]) -> String {
    let mut message = None;
    let mut rest = payload;
    while let Some((&code, body)) = rest.split_first() {
        if code == 0 {
            break;
        }
        let end = body.iter().position(|&b| b == 0).unwrap_or(body.len());
        if code == b'M' {
            message = Some(String::from_utf8_lossy(&body[..end]).into_owned());
        }
        // Skip past the value's NUL; an unterminated value ends the scan.
        rest = body.get(end + 1..).unwrap_or(&[]);
    }
    message.unwrap_or_else(|| "<no message>".to_string())
}
569
570async fn read_one(
571    stream: &mut Stream,
572    buffer: &mut BytesMut,
573    codec: &ProtocolCodec,
574) -> BackendResult<Message> {
575    loop {
576        if let Some(mut msg) = codec
577            .decode_message(buffer)
578            .map_err(|e| BackendError::Protocol(e.to_string()))?
579        {
580            // The shared tag decoder is direction-agnostic and resolves the
581            // tags that collide between client and server frames to their
582            // CLIENT-side meaning. `read_one` only ever reads SERVER frames,
583            // so remap those collisions to their server semantics:
584            //   'S' Sync->ParameterStatus, 'D' Describe->DataRow,
585            //   'E' Execute->ErrorResponse, 'C' Close->CommandComplete.
586            msg.msg_type = match msg.msg_type {
587                MessageType::Sync => MessageType::ParameterStatus,
588                MessageType::Describe => MessageType::DataRow,
589                MessageType::Execute => MessageType::ErrorResponse,
590                MessageType::Close => MessageType::CommandComplete,
591                other => other,
592            };
593            return Ok(msg);
594        }
595        let mut tmp = vec![0u8; 4096];
596        let n = stream.read(&mut tmp).await?;
597        if n == 0 {
598            return Err(BackendError::Closed);
599        }
600        buffer.extend_from_slice(&tmp[..n]);
601    }
602}
603
604async fn handle_auth(
605    stream: &mut Stream,
606    msg: &Message,
607    cfg: &BackendConfig,
608    scram_state: &mut Option<Scram>,
609) -> BackendResult<()> {
610    if msg.payload.len() < 4 {
611        return Err(BackendError::Protocol(
612            "AuthRequest payload < 4 bytes".into(),
613        ));
614    }
615    let code =
616        u32::from_be_bytes([msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3]]);
617    match code {
618        0 => Ok(()), // AuthenticationOk
619        5 => {
620            // AuthenticationMD5Password + 4-byte salt.
621            if msg.payload.len() < 8 {
622                return Err(BackendError::Protocol("AuthenticationMD5 truncated".into()));
623            }
624            let salt: [u8; 4] = [
625                msg.payload[4],
626                msg.payload[5],
627                msg.payload[6],
628                msg.payload[7],
629            ];
630            let password = cfg.password.as_deref().ok_or_else(|| {
631                BackendError::Auth("server requested MD5 but no password configured".into())
632            })?;
633            let payload = md5_password_response(&cfg.user, password, &salt);
634            write_password_message(stream, &payload).await
635        }
636        3 => {
637            // AuthenticationCleartextPassword — plain password with null terminator.
638            let password = cfg.password.as_deref().ok_or_else(|| {
639                BackendError::Auth("server requested password but none configured".into())
640            })?;
641            let mut payload = Vec::with_capacity(password.len() + 1);
642            payload.extend_from_slice(password.as_bytes());
643            payload.push(0);
644            write_password_message(stream, &payload).await
645        }
646        10 => {
647            // AuthenticationSASL: list of mechanism cstrings, then 0.
648            let mechs = parse_sasl_mechanisms(&msg.payload[4..]);
649            if !mechs.iter().any(|m| m == "SCRAM-SHA-256") {
650                return Err(BackendError::Auth(format!(
651                    "no supported SASL mechanism; server offered {:?}",
652                    mechs
653                )));
654            }
655            let nonce = generate_nonce();
656            let (scram, first) = Scram::client_first(nonce);
657            *scram_state = Some(scram);
658            write_password_message(stream, &first.0).await
659        }
660        11 => {
661            // AuthenticationSASLContinue: server-first bytes.
662            let scram = scram_state.as_mut().ok_or_else(|| {
663                BackendError::Auth("SASLContinue before SASL start".into())
664            })?;
665            let password = cfg.password.as_deref().ok_or_else(|| {
666                BackendError::Auth("SCRAM requires a password".into())
667            })?;
668            let out = scram.client_final(&msg.payload[4..], password)?;
669            write_password_message(stream, &out.0).await
670        }
671        12 => {
672            // AuthenticationSASLFinal: v=<signature>
673            let scram = scram_state.as_ref().ok_or_else(|| {
674                BackendError::Auth("SASLFinal before SASL start".into())
675            })?;
676            scram.verify_server(&msg.payload[4..])
677        }
678        other => Err(BackendError::Auth(format!(
679            "unsupported authentication request code: {}",
680            other
681        ))),
682    }
683}
684
685async fn write_password_message(
686    stream: &mut Stream,
687    payload: &[u8],
688) -> BackendResult<()> {
689    let mut buf = BytesMut::with_capacity(payload.len() + 5);
690    buf.put_u8(b'p');
691    buf.put_u32((payload.len() + 4) as u32);
692    buf.extend_from_slice(payload);
693    stream.write_all(&buf).await?;
694    Ok(())
695}
696
/// Parses the mechanism list of an `AuthenticationSASL` message: a series
/// of NUL-terminated names ended by an empty name. Decoding stops at the
/// first empty name; a trailing name without a terminator is still included.
fn parse_sasl_mechanisms(payload: &[u8]) -> Vec<String> {
    payload
        .split(|&b| b == 0)
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
713
714fn generate_nonce() -> String {
715    use base64::Engine as _;
716    use rand::RngCore;
717    let mut bytes = [0u8; 18];
718    rand::thread_rng().fill_bytes(&mut bytes);
719    // URL-safe base64 without padding keeps it ASCII.
720    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
721}
722
723fn interpolate_params(sql: &str, params: &[ParamValue]) -> BackendResult<String> {
724    // Walk the SQL, replacing $N tokens outside of string literals with
725    // the encoded parameter. This is a deliberately simple interpolator
726    // for internal management queries; it does NOT try to be a full
727    // PG parser.
728    let mut out = String::with_capacity(sql.len());
729    let bytes = sql.as_bytes();
730    let mut i = 0;
731    let mut in_string = false;
732    let mut quote_char = 0u8;
733    while i < bytes.len() {
734        let b = bytes[i];
735        if in_string {
736            out.push(b as char);
737            if b == quote_char {
738                // PG doubles the quote to escape; peek ahead.
739                if i + 1 < bytes.len() && bytes[i + 1] == quote_char {
740                    out.push(quote_char as char);
741                    i += 2;
742                    continue;
743                }
744                in_string = false;
745            }
746            i += 1;
747            continue;
748        }
749        if b == b'\'' || b == b'"' {
750            in_string = true;
751            quote_char = b;
752            out.push(b as char);
753            i += 1;
754            continue;
755        }
756        if b == b'$' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
757            let mut j = i + 1;
758            while j < bytes.len() && bytes[j].is_ascii_digit() {
759                j += 1;
760            }
761            let idx: usize = std::str::from_utf8(&bytes[i + 1..j])
762                .unwrap()
763                .parse()
764                .map_err(|_| {
765                    BackendError::Protocol(format!(
766                        "invalid parameter reference at byte {}",
767                        i
768                    ))
769                })?;
770            if idx == 0 || idx > params.len() {
771                return Err(BackendError::Protocol(format!(
772                    "parameter ${} out of range (have {})",
773                    idx,
774                    params.len()
775                )));
776            }
777            out.push_str(&encode_literal(&params[idx - 1]));
778            i = j;
779            continue;
780        }
781        out.push(b as char);
782        i += 1;
783    }
784    Ok(out)
785}
786
/// Returns the prefix of `s` holding at most `n` characters, cutting only
/// on a char boundary.
fn truncate(s: &str, n: usize) -> &str {
    let cut = s
        .char_indices()
        .map(|(byte_idx, _)| byte_idx)
        .nth(n)
        .unwrap_or(s.len());
    &s[..cut]
}
793
#[cfg(test)]
mod tests {
    //! Unit tests for the socket-free helpers in this file: startup-packet
    //! framing, client-side `$n` parameter interpolation, and the parsers for
    //! backend protocol messages.

    use super::*;

    /// Returns `true` if `needle` occurs as a contiguous run inside `haystack`.
    fn contains_subslice(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|w| w == needle)
    }

    #[test]
    fn test_build_startup_has_user_and_protocol_version() {
        let cfg = BackendConfig {
            host: "localhost".into(),
            port: 5432,
            user: "alice".into(),
            password: None,
            database: Some("app".into()),
            application_name: None,
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_secs(5),
            query_timeout: Duration::from_secs(5),
            tls_config: crate::backend::tls::default_client_config(),
        };
        let bytes = build_startup(&cfg);
        // The Int32 length prefix counts itself, so it must equal the packet size.
        let declared_len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        assert_eq!(declared_len as usize, bytes.len());
        // Next Int32: protocol version 3.0 = 196608 (3 << 16).
        assert_eq!(
            u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]),
            196608
        );
        // Match whole NUL-terminated key/value pairs. A wrong or truncated
        // value then fails the test, not only a missing key.
        assert!(contains_subslice(&bytes, b"user\0alice\0"));
        assert!(contains_subslice(&bytes, b"database\0app\0"));
        // The parameter list is closed by a final NUL byte.
        assert_eq!(bytes.last(), Some(&0));
    }

    #[test]
    fn test_interpolate_params_basic() {
        let params = vec![
            ParamValue::Int(42),
            ParamValue::Text("alice".into()),
        ];
        let sql = "SELECT * FROM t WHERE id = $1 AND name = $2";
        let out = interpolate_params(sql, &params).unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE id = 42 AND name = 'alice'");
    }

    #[test]
    fn test_interpolate_params_escapes_quotes() {
        let params = vec![ParamValue::Text("o'brien".into())];
        let out =
            interpolate_params("SELECT * FROM t WHERE name = $1", &params).unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE name = 'o''brien'");
    }

    #[test]
    fn test_interpolate_params_leaves_dollar_in_string_alone() {
        let params = vec![ParamValue::Int(1)];
        let sql = "SELECT '$1' AS lit, $1 AS val";
        let out = interpolate_params(sql, &params).unwrap();
        assert_eq!(out, "SELECT '$1' AS lit, 1 AS val");
    }

    #[test]
    fn test_interpolate_params_respects_doubled_quote_inside_literal() {
        // `''` inside a literal is an escaped quote, not the end of the
        // string. The `$1` that follows it is therefore still literal text.
        let params = vec![ParamValue::Int(7)];
        let out = interpolate_params("SELECT 'it''s $1', $1", &params).unwrap();
        assert_eq!(out, "SELECT 'it''s $1', 7");
    }

    #[test]
    fn test_interpolate_params_leaves_quoted_identifier_alone() {
        // Double-quoted identifiers are scanned like literals, so a `$1`
        // inside them must not be substituted.
        let params = vec![ParamValue::Int(1)];
        let out =
            interpolate_params(r#"SELECT "col$1" FROM t WHERE x = $1"#, &params).unwrap();
        assert_eq!(out, r#"SELECT "col$1" FROM t WHERE x = 1"#);
    }

    #[test]
    fn test_interpolate_params_ignores_non_numeric_dollar() {
        // Only `$` followed by a digit is a placeholder.
        let params = vec![ParamValue::Int(1)];
        let out = interpolate_params("SELECT $a, $1", &params).unwrap();
        assert_eq!(out, "SELECT $a, 1");
    }

    #[test]
    fn test_interpolate_params_out_of_range() {
        let params = vec![ParamValue::Int(1)];
        let err = interpolate_params("SELECT $2", &params).unwrap_err();
        assert!(matches!(err, BackendError::Protocol(_)));
        // Placeholders are 1-based, so `$0` is always rejected.
        let err = interpolate_params("SELECT $0", &params).unwrap_err();
        assert!(matches!(err, BackendError::Protocol(_)));
    }

    #[test]
    fn test_parse_row_description_shape() {
        // numFields=1, name="x"\0, tableOid=0, col#=0, typeOid=23, len=4, mod=-1, fmt=0
        let mut p = BytesMut::new();
        p.put_u16(1);
        p.extend_from_slice(b"x");
        p.put_u8(0);
        p.put_u32(0); // table oid
        p.put_u16(0); // col #
        p.put_u32(23); // int4
        p.put_i16(4);
        p.put_i32(-1);
        p.put_u16(0);
        let cols = parse_row_description(&p);
        assert_eq!(cols.len(), 1);
        assert_eq!(cols[0].name, "x");
        assert_eq!(cols[0].type_oid, 23);
    }

    #[test]
    fn test_parse_data_row_with_null() {
        // numFields=2, len=1/'a', len=-1 (NULL)
        let mut p = BytesMut::new();
        p.put_u16(2);
        p.put_i32(1);
        p.extend_from_slice(b"a");
        p.put_i32(-1);
        let row = parse_data_row(&p, 2).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0], TextValue::Text("a".into()));
        assert_eq!(row[1], TextValue::Null);
    }

    #[test]
    fn test_error_message_extracts_m_field() {
        let mut p = Vec::new();
        p.push(b'S');
        p.extend_from_slice(b"ERROR\0");
        p.push(b'C');
        p.extend_from_slice(b"28P01\0");
        p.push(b'M');
        p.extend_from_slice(b"password authentication failed\0");
        p.push(0);
        assert_eq!(error_message(&p), "password authentication failed");
    }

    #[test]
    fn test_parse_parameter_status() {
        let mut p = Vec::new();
        p.extend_from_slice(b"client_encoding\0");
        p.extend_from_slice(b"UTF8\0");
        let (k, v) = parse_parameter_status(&p).unwrap();
        assert_eq!(k, "client_encoding");
        assert_eq!(v, "UTF8");
    }

    #[test]
    fn test_parse_sasl_mechanisms() {
        let mut p = Vec::new();
        p.extend_from_slice(b"SCRAM-SHA-256\0");
        p.extend_from_slice(b"SCRAM-SHA-256-PLUS\0");
        p.push(0);
        let m = parse_sasl_mechanisms(&p);
        assert_eq!(m.len(), 2);
        assert_eq!(m[0], "SCRAM-SHA-256");
        assert_eq!(m[1], "SCRAM-SHA-256-PLUS");
    }

    #[test]
    fn test_generate_nonce_is_url_safe() {
        let n = generate_nonce();
        assert!(n.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
        assert!(n.len() >= 18);
    }

    #[test]
    fn test_query_result_rows_affected() {
        let r = QueryResult {
            columns: Vec::new(),
            rows: Vec::new(),
            command_tag: "INSERT 0 5".into(),
        };
        assert_eq!(r.rows_affected(), Some(5));
        let r = QueryResult {
            columns: Vec::new(),
            rows: Vec::new(),
            command_tag: "SET".into(),
        };
        assert_eq!(r.rows_affected(), None);
    }
}