Skip to main content

heliosdb_proxy/backend/
client.rs

1//! Backend PostgreSQL client.
2//!
3//! Originates simple-query SQL against a backend node. Frames every
4//! message through the existing `crate::protocol` types so we keep a
5//! single wire-protocol implementation in the crate.
6//!
7//! The code path here is distinct from `ProxyServer::route_and_forward`,
8//! which *forwards* client frames. That path remains zero-copy;
9//! this path is for **internal TR management** queries (health check,
10//! `pg_is_in_recovery()`, `pg_promote()`, WAL-position probes, failover
11//! replay, session-state restoration).
12
13use super::auth::{md5_password_response, Scram};
14use super::error::{BackendError, BackendResult};
15use super::stream::Stream;
16use super::tls::{negotiate, TlsMode};
17use super::types::{encode_literal, ParamValue, TextValue};
18use crate::protocol::{Message, MessageType, ProtocolCodec};
19use bytes::{Buf, BufMut, BytesMut};
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::TcpStream;
24
25/// Backend connection parameters.
26#[derive(Debug, Clone)]
27pub struct BackendConfig {
28    /// Hostname (also used for TLS SNI).
29    pub host: String,
30    /// TCP port.
31    pub port: u16,
32    /// PostgreSQL user.
33    pub user: String,
34    /// PostgreSQL password. If `None`, only `AuthenticationOk` (no
35    /// password required) is accepted.
36    pub password: Option<String>,
37    /// Target database. If `None`, connects to the user's default.
38    pub database: Option<String>,
39    /// Application name reported via the `application_name` startup
40    /// parameter. Defaults to `heliosdb-proxy` when `None`.
41    pub application_name: Option<String>,
42    /// TLS policy.
43    pub tls_mode: TlsMode,
44    /// Connect-timeout ceiling (covers DNS + TCP + TLS + startup).
45    pub connect_timeout: Duration,
46    /// Per-query ceiling (round-trip from Query send to ReadyForQuery).
47    pub query_timeout: Duration,
48    /// Shared rustls `ClientConfig` — build once via
49    /// `super::tls::default_client_config` and reuse across connections.
50    pub tls_config: Arc<rustls::ClientConfig>,
51}
52
53impl BackendConfig {
54    pub fn address(&self) -> String {
55        format!("{}:{}", self.host, self.port)
56    }
57}
58
59/// An established, authenticated client connection to a backend.
60pub struct BackendClient {
61    stream: Stream,
62    /// Parameter values the server sent during startup (client_encoding,
63    /// server_version, TimeZone, …). Useful for diagnostics.
64    pub server_parameters: std::collections::HashMap<String, String>,
65    /// BackendKeyData cached for potential cancel requests.
66    pub backend_pid: Option<u32>,
67    pub backend_secret: Option<u32>,
68}
69
70impl BackendClient {
71    /// Connect, TLS-negotiate, authenticate, and drain server-initialisation
72    /// frames through `ReadyForQuery`. On success the client is idle and
73    /// ready to run SQL.
74    pub async fn connect(cfg: &BackendConfig) -> BackendResult<Self> {
75        tokio::time::timeout(cfg.connect_timeout, Self::connect_inner(cfg))
76            .await
77            .map_err(|_| {
78                BackendError::Io(std::io::Error::new(
79                    std::io::ErrorKind::TimedOut,
80                    format!(
81                        "connect to {} exceeded {:?}",
82                        cfg.address(),
83                        cfg.connect_timeout
84                    ),
85                ))
86            })?
87    }
88
89    async fn connect_inner(cfg: &BackendConfig) -> BackendResult<Self> {
90        let tcp = TcpStream::connect(cfg.address()).await?;
91        let mut stream = negotiate(tcp, cfg.tls_mode, cfg.tls_config.clone(), &cfg.host).await?;
92
93        // Send StartupMessage (protocol 3.0).
94        let startup = build_startup(cfg);
95        stream.write_all(&startup).await?;
96
97        let mut server_parameters = std::collections::HashMap::new();
98        let mut backend_pid = None;
99        let mut backend_secret = None;
100        let mut buffer = BytesMut::with_capacity(4096);
101        let codec = ProtocolCodec::new();
102        let mut scram_state: Option<Scram> = None;
103
104        loop {
105            let msg = read_one(&mut stream, &mut buffer, &codec).await?;
106            match msg.msg_type {
107                MessageType::AuthRequest => {
108                    handle_auth(&mut stream, &msg, cfg, &mut scram_state).await?;
109                }
110                MessageType::ParameterStatus => {
111                    if let Some((k, v)) = parse_parameter_status(&msg.payload) {
112                        server_parameters.insert(k, v);
113                    }
114                }
115                MessageType::BackendKeyData => {
116                    if msg.payload.len() >= 8 {
117                        backend_pid =
118                            Some(u32::from_be_bytes(msg.payload[0..4].try_into().unwrap()));
119                        backend_secret =
120                            Some(u32::from_be_bytes(msg.payload[4..8].try_into().unwrap()));
121                    }
122                }
123                MessageType::ReadyForQuery => {
124                    return Ok(Self {
125                        stream,
126                        server_parameters,
127                        backend_pid,
128                        backend_secret,
129                    });
130                }
131                MessageType::ErrorResponse => {
132                    return Err(BackendError::BackendError(error_message(&msg.payload)));
133                }
134                MessageType::NoticeResponse => {
135                    // Ignore warnings during startup.
136                }
137                other => {
138                    return Err(BackendError::Protocol(format!(
139                        "unexpected message during startup: {:?}",
140                        other
141                    )));
142                }
143            }
144        }
145    }
146
147    /// Run a simple-query (SQL text) and collect every resulting row
148    /// into a `QueryResult`. For statements that don't return rows
149    /// (DDL, `SET`, etc.) `rows` will be empty and `command_tag`
150    /// carries the completion string.
151    pub async fn simple_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
152        self.run_query(sql).await
153    }
154
155    /// Like `simple_query` but substitutes `$1`, `$2`, … with
156    /// text-format literals before sending. We stick to simple-query
157    /// rather than the extended protocol to keep the surface narrow.
158    pub async fn query_with_params(
159        &mut self,
160        sql: &str,
161        params: &[ParamValue],
162    ) -> BackendResult<QueryResult> {
163        let substituted = interpolate_params(sql, params)?;
164        self.run_query(&substituted).await
165    }
166
167    /// Shorthand for a scalar lookup: runs `sql`, expects 1 column, 1 row.
168    pub async fn query_scalar(&mut self, sql: &str) -> BackendResult<TextValue> {
169        let res = self.simple_query(sql).await?;
170        if res.rows.len() != 1 {
171            return Err(BackendError::Protocol(format!(
172                "expected 1 row, got {}",
173                res.rows.len()
174            )));
175        }
176        if res.columns.len() != 1 {
177            return Err(BackendError::Protocol(format!(
178                "expected 1 column, got {}",
179                res.columns.len()
180            )));
181        }
182        Ok(res
183            .rows
184            .into_iter()
185            .next()
186            .unwrap()
187            .into_iter()
188            .next()
189            .unwrap())
190    }
191
192    /// Run a statement with no result set (DDL, SET, DO). Returns the
193    /// command tag (e.g. `"SET"`, `"CREATE TABLE"`).
194    pub async fn execute(&mut self, sql: &str) -> BackendResult<String> {
195        let res = self.simple_query(sql).await?;
196        Ok(res.command_tag)
197    }
198
199    /// Bulk-load rows via `COPY <table> [(cols)] FROM STDIN` (text format).
200    /// `data` is the pre-encoded COPY text payload (rows already tab-delimited,
201    /// escaped, and newline-terminated; no trailing `\.`). Returns the row count
202    /// from the `COPY n` command tag.
203    ///
204    /// On ANY failure the connection is drained back to `ReadyForQuery`, so a
205    /// caller can fall back to per-row INSERTs cleanly — `COPY` is atomic, so a
206    /// failed load leaves zero rows behind (no double-insert risk).
207    pub async fn copy_in(&mut self, copy_sql: &str, data: &[u8]) -> BackendResult<u64> {
208        // Bulk loads can run long; give COPY a generous ceiling vs the 30s
209        // management-query default.
210        let t = Duration::from_secs(600);
211        tokio::time::timeout(t, Self::copy_in_inner(&mut self.stream, copy_sql, data))
212            .await
213            .map_err(|_| {
214                BackendError::Io(std::io::Error::new(
215                    std::io::ErrorKind::TimedOut,
216                    format!("COPY exceeded {:?}", t),
217                ))
218            })?
219    }
220
221    async fn copy_in_inner(stream: &mut Stream, copy_sql: &str, data: &[u8]) -> BackendResult<u64> {
222        // 1. Send the COPY ... FROM STDIN command as a simple Query.
223        let mut payload = BytesMut::with_capacity(copy_sql.len() + 1);
224        payload.extend_from_slice(copy_sql.as_bytes());
225        payload.put_u8(0);
226        stream
227            .write_all(&Message::new(MessageType::Query, payload).encode())
228            .await?;
229
230        let mut buffer = BytesMut::with_capacity(8192);
231        let codec = ProtocolCodec::new();
232
233        // 2. Expect CopyInResponse. Its tag 'G' isn't in the shared decoder, so
234        //    it surfaces as Unknown(b'G'). An ErrorResponse here means the
235        //    backend rejected COPY (e.g. unsupported) — drain to RFQ and surface
236        //    a recoverable error so the caller falls back to INSERTs.
237        loop {
238            let msg = read_one(stream, &mut buffer, &codec).await?;
239            match msg.msg_type {
240                MessageType::Unknown(b'G') => break,
241                MessageType::ErrorResponse => {
242                    let e = error_message(&msg.payload);
243                    Self::drain_to_ready(stream, &mut buffer, &codec).await?;
244                    return Err(BackendError::BackendError(e));
245                }
246                MessageType::ReadyForQuery => {
247                    return Err(BackendError::Protocol(
248                        "COPY: ReadyForQuery before CopyInResponse".into(),
249                    ));
250                }
251                _ => {} // NoticeResponse / ParameterStatus / etc — keep waiting
252            }
253        }
254
255        // 3. Stream the payload as CopyData frames, then CopyDone.
256        const CHUNK: usize = 64 * 1024;
257        let mut off = 0;
258        while off < data.len() {
259            let end = (off + CHUNK).min(data.len());
260            let mut p = BytesMut::with_capacity(end - off);
261            p.extend_from_slice(&data[off..end]);
262            stream
263                .write_all(&Message::new(MessageType::CopyData, p).encode())
264                .await?;
265            off = end;
266        }
267        stream
268            .write_all(&Message::new(MessageType::CopyDone, BytesMut::new()).encode())
269            .await?;
270
271        // 4. Read to ReadyForQuery: CommandComplete "COPY n" or ErrorResponse.
272        let mut tag = String::new();
273        let mut last_error = None;
274        loop {
275            let msg = read_one(stream, &mut buffer, &codec).await?;
276            match msg.msg_type {
277                MessageType::CommandComplete | MessageType::Close => {
278                    tag = parse_cstring(&msg.payload);
279                }
280                MessageType::ErrorResponse => {
281                    last_error = Some(error_message(&msg.payload));
282                }
283                MessageType::ReadyForQuery => {
284                    if let Some(e) = last_error {
285                        return Err(BackendError::BackendError(e));
286                    }
287                    // "COPY n" -> n
288                    let n = tag
289                        .rsplit(' ')
290                        .next()
291                        .and_then(|s| s.parse::<u64>().ok())
292                        .unwrap_or(0);
293                    return Ok(n);
294                }
295                _ => {}
296            }
297        }
298    }
299
300    async fn drain_to_ready(
301        stream: &mut Stream,
302        buffer: &mut BytesMut,
303        codec: &ProtocolCodec,
304    ) -> BackendResult<()> {
305        loop {
306            if read_one(stream, buffer, codec).await?.msg_type == MessageType::ReadyForQuery {
307                return Ok(());
308            }
309        }
310    }
311
312    async fn run_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
313        let t = self.stream_query_timeout();
314        tokio::time::timeout(t, Self::run_query_inner(&mut self.stream, sql))
315            .await
316            .map_err(|_| {
317                BackendError::Io(std::io::Error::new(
318                    std::io::ErrorKind::TimedOut,
319                    format!("query exceeded {:?}: {}", t, truncate(sql, 64)),
320                ))
321            })?
322    }
323
324    fn stream_query_timeout(&self) -> Duration {
325        // 30 seconds is a sane default for management queries; callers
326        // that need longer can wrap their own timeout around this call.
327        Duration::from_secs(30)
328    }
329
330    async fn run_query_inner(stream: &mut Stream, sql: &str) -> BackendResult<QueryResult> {
331        // Build and send a Query message (tag 'Q', payload = SQL + \0).
332        let mut payload = BytesMut::with_capacity(sql.len() + 1);
333        payload.extend_from_slice(sql.as_bytes());
334        payload.put_u8(0);
335        let frame = Message::new(MessageType::Query, payload).encode();
336        stream.write_all(&frame).await?;
337
338        let mut buffer = BytesMut::with_capacity(8192);
339        let codec = ProtocolCodec::new();
340        let mut columns: Vec<ColumnMeta> = Vec::new();
341        let mut rows: Vec<Vec<TextValue>> = Vec::new();
342        let mut command_tag = String::new();
343        let mut last_error: Option<String> = None;
344
345        loop {
346            let msg = read_one(stream, &mut buffer, &codec).await?;
347            match msg.msg_type {
348                // Both 'T' (RowDescription) and 'C' (CommandComplete)
349                // may appear; from_tag conflates T with… nothing. It
350                // DOES conflate D with Describe, but on a server→client
351                // frame the D tag here is always DataRow in practice,
352                // because we only arrive at run_query_inner after the
353                // startup handshake and the server never sends Describe.
354                MessageType::RowDescription => {
355                    columns = parse_row_description(&msg.payload);
356                }
357                MessageType::DataRow => {
358                    let row = parse_data_row(&msg.payload, columns.len())?;
359                    rows.push(row);
360                }
361                // PG tag 'C' = CommandComplete (server → client). The
362                // shared MessageType enum also has Close (client → server)
363                // under the same tag; again, the direction fixes the
364                // ambiguity at runtime.
365                MessageType::CommandComplete | MessageType::Close => {
366                    command_tag = parse_cstring(&msg.payload);
367                }
368                MessageType::EmptyQueryResponse => {
369                    command_tag = String::new();
370                }
371                MessageType::ErrorResponse => {
372                    last_error = Some(error_message(&msg.payload));
373                }
374                MessageType::NoticeResponse => {
375                    tracing::debug!(notice = %error_message(&msg.payload), "backend notice");
376                }
377                MessageType::ReadyForQuery => {
378                    if let Some(e) = last_error {
379                        return Err(BackendError::BackendError(e));
380                    }
381                    return Ok(QueryResult {
382                        columns,
383                        rows,
384                        command_tag,
385                    });
386                }
387                MessageType::ParameterStatus => {
388                    // Server may push parameter changes (e.g. after SET).
389                    // Ignore here; callers that care can call
390                    // simple_query("SHOW <param>") afterwards.
391                }
392                _other => {
393                    // Unknown message kind — keep draining until
394                    // ReadyForQuery. A well-behaved PG server won't
395                    // send anything strange outside of the above set.
396                }
397            }
398        }
399    }
400
401    /// Close the connection gracefully (send Terminate, close socket).
402    pub async fn close(mut self) {
403        let term = Message::new(MessageType::Terminate, BytesMut::new()).encode();
404        let _ = self.stream.write_all(&term).await;
405        let _ = self.stream.shutdown().await;
406    }
407
408    /// Report whether the underlying connection is over TLS.
409    pub fn is_tls(&self) -> bool {
410        self.stream.is_tls()
411    }
412}
413
414// ---------------------------------------------------------------------
415// Result shape
416// ---------------------------------------------------------------------
417
418#[derive(Debug, Clone)]
419pub struct ColumnMeta {
420    pub name: String,
421    pub type_oid: u32,
422}
423
424#[derive(Debug, Clone)]
425pub struct QueryResult {
426    pub columns: Vec<ColumnMeta>,
427    pub rows: Vec<Vec<TextValue>>,
428    pub command_tag: String,
429}
430
431impl QueryResult {
432    /// Return the numeric suffix of a CommandComplete tag, if any.
433    /// For example, `"INSERT 0 5"` → `Some(5)`.
434    pub fn rows_affected(&self) -> Option<u64> {
435        self.command_tag
436            .split_whitespace()
437            .last()
438            .and_then(|s| s.parse().ok())
439    }
440}
441
442// ---------------------------------------------------------------------
443// Helpers
444// ---------------------------------------------------------------------
445
446fn build_startup(cfg: &BackendConfig) -> Vec<u8> {
447    let mut payload = BytesMut::with_capacity(128);
448    // Protocol version 3.0.
449    payload.put_u32(196608);
450    put_cstring(&mut payload, "user");
451    put_cstring(&mut payload, &cfg.user);
452    if let Some(db) = &cfg.database {
453        put_cstring(&mut payload, "database");
454        put_cstring(&mut payload, db);
455    }
456    put_cstring(&mut payload, "application_name");
457    put_cstring(
458        &mut payload,
459        cfg.application_name.as_deref().unwrap_or("heliosdb-proxy"),
460    );
461    put_cstring(&mut payload, "client_encoding");
462    put_cstring(&mut payload, "UTF8");
463    payload.put_u8(0); // terminator
464
465    let mut framed = BytesMut::with_capacity(payload.len() + 4);
466    framed.put_u32((payload.len() + 4) as u32);
467    framed.extend_from_slice(&payload);
468    framed.to_vec()
469}
470
471fn put_cstring(buf: &mut BytesMut, s: &str) {
472    buf.extend_from_slice(s.as_bytes());
473    buf.put_u8(0);
474}
475
/// Decode the leading C string of `payload`. Everything before the first
/// NUL is returned (or the whole payload if there is no NUL), with invalid
/// UTF-8 replaced lossily.
fn parse_cstring(payload: &[u8]) -> String {
    let text = payload.split(|&b| b == 0).next().unwrap_or(payload);
    String::from_utf8_lossy(text).into_owned()
}
483
/// Split a ParameterStatus payload (`name\0value\0`) into `(name, value)`.
///
/// Returns `None` when the name is not NUL-terminated. A missing
/// terminator on the value is tolerated: the value then runs to the end
/// of the payload.
fn parse_parameter_status(payload: &[u8]) -> Option<(String, String)> {
    let name_end = payload.iter().position(|&b| b == 0)?;
    let (name, rest) = (&payload[..name_end], &payload[name_end + 1..]);
    let value = rest.split(|&b| b == 0).next().unwrap_or(rest);
    Some((
        String::from_utf8_lossy(name).into_owned(),
        String::from_utf8_lossy(value).into_owned(),
    ))
}
492
493fn parse_row_description(payload: &[u8]) -> Vec<ColumnMeta> {
494    let mut p = BytesMut::from(payload);
495    if p.remaining() < 2 {
496        return Vec::new();
497    }
498    let n = p.get_u16() as usize;
499    let mut cols = Vec::with_capacity(n);
500    for _ in 0..n {
501        // cstring name
502        let end = match p.as_ref().iter().position(|&b| b == 0) {
503            Some(i) => i,
504            None => break,
505        };
506        let name = String::from_utf8_lossy(&p.as_ref()[..end]).into_owned();
507        p.advance(end + 1);
508        if p.remaining() < 18 {
509            break;
510        }
511        let _table_oid = p.get_u32();
512        let _column_number = p.get_u16();
513        let type_oid = p.get_u32();
514        let _type_len = p.get_i16();
515        let _type_mod = p.get_i32();
516        let _format_code = p.get_u16();
517        cols.push(ColumnMeta { name, type_oid });
518    }
519    cols
520}
521
522fn parse_data_row(payload: &[u8], column_count: usize) -> BackendResult<Vec<TextValue>> {
523    let mut p = BytesMut::from(payload);
524    if p.remaining() < 2 {
525        return Err(BackendError::Protocol("truncated DataRow".into()));
526    }
527    let n = p.get_u16() as usize;
528    let mut out = Vec::with_capacity(n);
529    for _ in 0..n {
530        if p.remaining() < 4 {
531            return Err(BackendError::Protocol("truncated DataRow field".into()));
532        }
533        let len = p.get_i32();
534        if len == -1 {
535            out.push(TextValue::Null);
536        } else {
537            let len = len as usize;
538            if p.remaining() < len {
539                return Err(BackendError::Protocol("truncated DataRow value".into()));
540            }
541            let bytes = p.split_to(len);
542            out.push(TextValue::Text(
543                String::from_utf8_lossy(&bytes).into_owned(),
544            ));
545        }
546    }
547    let _ = column_count;
548    Ok(out)
549}
550
/// Extract the human-readable `M` field from an ErrorResponse or
/// NoticeResponse payload.
///
/// The payload is a sequence of `{code: u8, value: cstring}` fields ending
/// with a zero code. If several `M` fields appear the last one wins; if
/// there is none, `"<no message>"` is returned.
fn error_message(payload: &[u8]) -> String {
    let mut message = None;
    let mut rest = payload;
    while let Some((&field, tail)) = rest.split_first() {
        if field == 0 {
            break;
        }
        let value_end = tail.iter().position(|&b| b == 0).unwrap_or(tail.len());
        if field == b'M' {
            message = Some(String::from_utf8_lossy(&tail[..value_end]).into_owned());
        }
        rest = tail.get(value_end + 1..).unwrap_or(&[]);
    }
    message.unwrap_or_else(|| "<no message>".to_string())
}
574
575async fn read_one(
576    stream: &mut Stream,
577    buffer: &mut BytesMut,
578    codec: &ProtocolCodec,
579) -> BackendResult<Message> {
580    loop {
581        if let Some(mut msg) = codec
582            .decode_message(buffer)
583            .map_err(|e| BackendError::Protocol(e.to_string()))?
584        {
585            // The shared tag decoder is direction-agnostic and resolves the
586            // tags that collide between client and server frames to their
587            // CLIENT-side meaning. `read_one` only ever reads SERVER frames,
588            // so remap those collisions to their server semantics:
589            //   'S' Sync->ParameterStatus, 'D' Describe->DataRow,
590            //   'E' Execute->ErrorResponse, 'C' Close->CommandComplete.
591            msg.msg_type = match msg.msg_type {
592                MessageType::Sync => MessageType::ParameterStatus,
593                MessageType::Describe => MessageType::DataRow,
594                MessageType::Execute => MessageType::ErrorResponse,
595                MessageType::Close => MessageType::CommandComplete,
596                other => other,
597            };
598            return Ok(msg);
599        }
600        let mut tmp = vec![0u8; 4096];
601        let n = stream.read(&mut tmp).await?;
602        if n == 0 {
603            return Err(BackendError::Closed);
604        }
605        buffer.extend_from_slice(&tmp[..n]);
606    }
607}
608
609async fn handle_auth(
610    stream: &mut Stream,
611    msg: &Message,
612    cfg: &BackendConfig,
613    scram_state: &mut Option<Scram>,
614) -> BackendResult<()> {
615    if msg.payload.len() < 4 {
616        return Err(BackendError::Protocol(
617            "AuthRequest payload < 4 bytes".into(),
618        ));
619    }
620    let code = u32::from_be_bytes([
621        msg.payload[0],
622        msg.payload[1],
623        msg.payload[2],
624        msg.payload[3],
625    ]);
626    match code {
627        0 => Ok(()), // AuthenticationOk
628        5 => {
629            // AuthenticationMD5Password + 4-byte salt.
630            if msg.payload.len() < 8 {
631                return Err(BackendError::Protocol("AuthenticationMD5 truncated".into()));
632            }
633            let salt: [u8; 4] = [
634                msg.payload[4],
635                msg.payload[5],
636                msg.payload[6],
637                msg.payload[7],
638            ];
639            let password = cfg.password.as_deref().ok_or_else(|| {
640                BackendError::Auth("server requested MD5 but no password configured".into())
641            })?;
642            let payload = md5_password_response(&cfg.user, password, &salt);
643            write_password_message(stream, &payload).await
644        }
645        3 => {
646            // AuthenticationCleartextPassword — plain password with null terminator.
647            let password = cfg.password.as_deref().ok_or_else(|| {
648                BackendError::Auth("server requested password but none configured".into())
649            })?;
650            let mut payload = Vec::with_capacity(password.len() + 1);
651            payload.extend_from_slice(password.as_bytes());
652            payload.push(0);
653            write_password_message(stream, &payload).await
654        }
655        10 => {
656            // AuthenticationSASL: list of mechanism cstrings, then 0.
657            let mechs = parse_sasl_mechanisms(&msg.payload[4..]);
658            if !mechs.iter().any(|m| m == "SCRAM-SHA-256") {
659                return Err(BackendError::Auth(format!(
660                    "no supported SASL mechanism; server offered {:?}",
661                    mechs
662                )));
663            }
664            let nonce = generate_nonce();
665            let (scram, first) = Scram::client_first(nonce);
666            *scram_state = Some(scram);
667            write_password_message(stream, &first.0).await
668        }
669        11 => {
670            // AuthenticationSASLContinue: server-first bytes.
671            let scram = scram_state
672                .as_mut()
673                .ok_or_else(|| BackendError::Auth("SASLContinue before SASL start".into()))?;
674            let password = cfg
675                .password
676                .as_deref()
677                .ok_or_else(|| BackendError::Auth("SCRAM requires a password".into()))?;
678            let out = scram.client_final(&msg.payload[4..], password)?;
679            write_password_message(stream, &out.0).await
680        }
681        12 => {
682            // AuthenticationSASLFinal: v=<signature>
683            let scram = scram_state
684                .as_ref()
685                .ok_or_else(|| BackendError::Auth("SASLFinal before SASL start".into()))?;
686            scram.verify_server(&msg.payload[4..])
687        }
688        other => Err(BackendError::Auth(format!(
689            "unsupported authentication request code: {}",
690            other
691        ))),
692    }
693}
694
695async fn write_password_message(stream: &mut Stream, payload: &[u8]) -> BackendResult<()> {
696    let mut buf = BytesMut::with_capacity(payload.len() + 5);
697    buf.put_u8(b'p');
698    buf.put_u32((payload.len() + 4) as u32);
699    buf.extend_from_slice(payload);
700    stream.write_all(&buf).await?;
701    Ok(())
702}
703
/// Decode the NUL-separated SASL mechanism list that follows the
/// AuthenticationSASL code. Decoding stops at the first empty name (the
/// list terminator); a final name without a NUL is still returned.
fn parse_sasl_mechanisms(payload: &[u8]) -> Vec<String> {
    payload
        .split(|&b| b == 0)
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
720
721fn generate_nonce() -> String {
722    use base64::Engine as _;
723    use rand::RngCore;
724    let mut bytes = [0u8; 18];
725    rand::thread_rng().fill_bytes(&mut bytes);
726    // URL-safe base64 without padding keeps it ASCII.
727    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
728}
729
730fn interpolate_params(sql: &str, params: &[ParamValue]) -> BackendResult<String> {
731    // Walk the SQL, replacing $N tokens outside of string literals with
732    // the encoded parameter. This is a deliberately simple interpolator
733    // for internal management queries; it does NOT try to be a full
734    // PG parser.
735    let mut out = String::with_capacity(sql.len());
736    let bytes = sql.as_bytes();
737    let mut i = 0;
738    let mut in_string = false;
739    let mut quote_char = 0u8;
740    while i < bytes.len() {
741        let b = bytes[i];
742        if in_string {
743            out.push(b as char);
744            if b == quote_char {
745                // PG doubles the quote to escape; peek ahead.
746                if i + 1 < bytes.len() && bytes[i + 1] == quote_char {
747                    out.push(quote_char as char);
748                    i += 2;
749                    continue;
750                }
751                in_string = false;
752            }
753            i += 1;
754            continue;
755        }
756        if b == b'\'' || b == b'"' {
757            in_string = true;
758            quote_char = b;
759            out.push(b as char);
760            i += 1;
761            continue;
762        }
763        if b == b'$' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
764            let mut j = i + 1;
765            while j < bytes.len() && bytes[j].is_ascii_digit() {
766                j += 1;
767            }
768            let idx: usize = std::str::from_utf8(&bytes[i + 1..j])
769                .unwrap()
770                .parse()
771                .map_err(|_| {
772                    BackendError::Protocol(format!("invalid parameter reference at byte {}", i))
773                })?;
774            if idx == 0 || idx > params.len() {
775                return Err(BackendError::Protocol(format!(
776                    "parameter ${} out of range (have {})",
777                    idx,
778                    params.len()
779                )));
780            }
781            out.push_str(&encode_literal(&params[idx - 1]));
782            i = j;
783            continue;
784        }
785        out.push(b as char);
786        i += 1;
787    }
788    Ok(out)
789}
790
/// Return at most the first `n` characters of `s` (counted as `char`s, so
/// the cut never splits a UTF-8 sequence).
fn truncate(s: &str, n: usize) -> &str {
    s.char_indices().nth(n).map_or(s, |(cut, _)| &s[..cut])
}
797
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::types::ParamValue;

    /// True when `needle` occurs as a contiguous byte run inside `hay`.
    ///
    /// Sizing the window from the needle avoids a mismatched hard-coded
    /// window length silently turning an assertion into "never matches".
    fn contains(hay: &[u8], needle: &[u8]) -> bool {
        hay.windows(needle.len()).any(|w| w == needle)
    }

    #[test]
    fn test_build_startup_has_user_and_protocol_version() {
        let cfg = BackendConfig {
            host: "localhost".into(),
            port: 5432,
            user: "alice".into(),
            password: None,
            database: Some("app".into()),
            application_name: None,
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_secs(5),
            query_timeout: Duration::from_secs(5),
            tls_config: crate::backend::tls::default_client_config(),
        };
        let bytes = build_startup(&cfg);
        // First 4 bytes: total length (including the length field itself);
        // next 4: protocol version 196608 (3 << 16).
        let declared_len = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        assert_eq!(declared_len as usize, bytes.len());
        assert_eq!(
            u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]),
            196608
        );
        // Match complete `key\0value\0` pairs so a wrong or truncated value
        // fails instead of passing on a prefix match.
        assert!(contains(&bytes, b"user\0alice\0"));
        assert!(contains(&bytes, b"database\0app\0"));
        // `application_name: None` falls back to the documented default.
        assert!(contains(&bytes, b"application_name\0heliosdb-proxy\0"));
        // The parameter list is closed by a single terminating NUL.
        assert_eq!(bytes.last(), Some(&0));
    }

    #[test]
    fn test_interpolate_params_basic() {
        let params = vec![ParamValue::Int(42), ParamValue::Text("alice".into())];
        let sql = "SELECT * FROM t WHERE id = $1 AND name = $2";
        let out = interpolate_params(sql, &params).unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE id = 42 AND name = 'alice'");
    }

    #[test]
    fn test_interpolate_params_escapes_quotes() {
        let params = vec![ParamValue::Text("o'brien".into())];
        let out = interpolate_params("SELECT * FROM t WHERE name = $1", &params).unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE name = 'o''brien'");
    }

    #[test]
    fn test_interpolate_params_leaves_dollar_in_string_alone() {
        let params = vec![ParamValue::Int(1)];
        let sql = "SELECT '$1' AS lit, $1 AS val";
        let out = interpolate_params(sql, &params).unwrap();
        assert_eq!(out, "SELECT '$1' AS lit, 1 AS val");
    }

    #[test]
    fn test_interpolate_params_doubled_quote_stays_in_string() {
        // `''` is an escaped quote, not the end of the literal, so the `$1`
        // after it must still be treated as literal text.
        let params = vec![ParamValue::Int(1)];
        let out = interpolate_params("SELECT 'it''s $1', $1", &params).unwrap();
        assert_eq!(out, "SELECT 'it''s $1', 1");
    }

    #[test]
    fn test_interpolate_params_skips_quoted_identifier() {
        let params = vec![ParamValue::Int(7)];
        let out =
            interpolate_params(r#"SELECT "$1" FROM t WHERE x = $1"#, &params).unwrap();
        assert_eq!(out, r#"SELECT "$1" FROM t WHERE x = 7"#);
    }

    #[test]
    fn test_interpolate_params_bare_dollar_passthrough() {
        // A `$` not followed by a digit is not a placeholder.
        let params: Vec<ParamValue> = Vec::new();
        let out = interpolate_params("SELECT $a", &params).unwrap();
        assert_eq!(out, "SELECT $a");
    }

    #[test]
    fn test_interpolate_params_out_of_range() {
        let params = vec![ParamValue::Int(1)];
        let err = interpolate_params("SELECT $2", &params).unwrap_err();
        assert!(matches!(err, BackendError::Protocol(_)));
    }

    #[test]
    fn test_interpolate_params_rejects_zero_index() {
        // Placeholders are 1-based; `$0` must be rejected, not underflow.
        let params = vec![ParamValue::Int(1)];
        let err = interpolate_params("SELECT $0", &params).unwrap_err();
        assert!(matches!(err, BackendError::Protocol(_)));
    }

    #[test]
    fn test_truncate_respects_char_boundaries() {
        assert_eq!(truncate("hello", 3), "hel");
        assert_eq!(truncate("hi", 5), "hi");
        assert_eq!(truncate("héllo", 2), "hé");
        assert_eq!(truncate("", 4), "");
    }

    #[test]
    fn test_parse_row_description_shape() {
        // numFields=1, name="x"\0, tableOid=0, col#=0, typeOid=23, len=4, mod=-1, fmt=0
        let mut p = BytesMut::new();
        p.put_u16(1);
        p.extend_from_slice(b"x");
        p.put_u8(0);
        p.put_u32(0); // table oid
        p.put_u16(0); // col #
        p.put_u32(23); // int4
        p.put_i16(4);
        p.put_i32(-1);
        p.put_u16(0);
        let cols = parse_row_description(&p);
        assert_eq!(cols.len(), 1);
        assert_eq!(cols[0].name, "x");
        assert_eq!(cols[0].type_oid, 23);
    }

    #[test]
    fn test_parse_data_row_with_null() {
        // numFields=2, len=1/'a', len=-1 (NULL)
        let mut p = BytesMut::new();
        p.put_u16(2);
        p.put_i32(1);
        p.extend_from_slice(b"a");
        p.put_i32(-1);
        let row = parse_data_row(&p, 2).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0], TextValue::Text("a".into()));
        assert_eq!(row[1], TextValue::Null);
    }

    #[test]
    fn test_error_message_extracts_m_field() {
        let mut p = Vec::new();
        p.push(b'S');
        p.extend_from_slice(b"ERROR\0");
        p.push(b'C');
        p.extend_from_slice(b"28P01\0");
        p.push(b'M');
        p.extend_from_slice(b"password authentication failed\0");
        p.push(0);
        assert_eq!(error_message(&p), "password authentication failed");
    }

    #[test]
    fn test_parse_parameter_status() {
        let mut p = Vec::new();
        p.extend_from_slice(b"client_encoding\0");
        p.extend_from_slice(b"UTF8\0");
        let (k, v) = parse_parameter_status(&p).unwrap();
        assert_eq!(k, "client_encoding");
        assert_eq!(v, "UTF8");
    }

    #[test]
    fn test_parse_sasl_mechanisms() {
        let mut p = Vec::new();
        p.extend_from_slice(b"SCRAM-SHA-256\0");
        p.extend_from_slice(b"SCRAM-SHA-256-PLUS\0");
        p.push(0);
        let m = parse_sasl_mechanisms(&p);
        assert_eq!(m.len(), 2);
        assert_eq!(m[0], "SCRAM-SHA-256");
        assert_eq!(m[1], "SCRAM-SHA-256-PLUS");
    }

    #[test]
    fn test_generate_nonce_is_url_safe() {
        let n = generate_nonce();
        assert!(n
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
        assert!(n.len() >= 18);
    }

    #[test]
    fn test_query_result_rows_affected() {
        let r = QueryResult {
            columns: Vec::new(),
            rows: Vec::new(),
            command_tag: "INSERT 0 5".into(),
        };
        assert_eq!(r.rows_affected(), Some(5));
        let r = QueryResult {
            columns: Vec::new(),
            rows: Vec::new(),
            command_tag: "SET".into(),
        };
        assert_eq!(r.rows_affected(), None);
    }
}