Skip to main content

heliosdb_proxy/backend/
client.rs

1//! Backend PostgreSQL client.
2//!
3//! Originates simple-query SQL against a backend node. Frames every
4//! message through the existing `crate::protocol` types so we keep a
5//! single wire-protocol implementation in the crate.
6//!
7//! The code path here is distinct from `ProxyServer::route_and_forward`,
8//! which *forwards* client frames. That path remains zero-copy;
9//! this path is for **internal TR management** queries (health check,
10//! `pg_is_in_recovery()`, `pg_promote()`, WAL-position probes, failover
11//! replay, session-state restoration).
12
13use super::auth::{md5_password_response, Scram};
14use super::error::{BackendError, BackendResult};
15use super::stream::Stream;
16use super::tls::{negotiate, TlsMode};
17use super::types::{encode_literal, ParamValue, TextValue};
18use crate::protocol::{Message, MessageType, ProtocolCodec};
19use bytes::{Buf, BufMut, BytesMut};
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::TcpStream;
24
25/// Backend connection parameters.
26#[derive(Debug, Clone)]
27pub struct BackendConfig {
28    /// Hostname (also used for TLS SNI).
29    pub host: String,
30    /// TCP port.
31    pub port: u16,
32    /// PostgreSQL user.
33    pub user: String,
34    /// PostgreSQL password. If `None`, only `AuthenticationOk` (no
35    /// password required) is accepted.
36    pub password: Option<String>,
37    /// Target database. If `None`, connects to the user's default.
38    pub database: Option<String>,
39    /// Application name reported via the `application_name` startup
40    /// parameter. Defaults to `heliosdb-proxy` when `None`.
41    pub application_name: Option<String>,
42    /// TLS policy.
43    pub tls_mode: TlsMode,
44    /// Connect-timeout ceiling (covers DNS + TCP + TLS + startup).
45    pub connect_timeout: Duration,
46    /// Per-query ceiling (round-trip from Query send to ReadyForQuery).
47    pub query_timeout: Duration,
48    /// Shared rustls `ClientConfig` — build once via
49    /// `super::tls::default_client_config` and reuse across connections.
50    pub tls_config: Arc<rustls::ClientConfig>,
51}
52
53impl BackendConfig {
54    pub fn address(&self) -> String {
55        format!("{}:{}", self.host, self.port)
56    }
57}
58
59/// An established, authenticated client connection to a backend.
60pub struct BackendClient {
61    stream: Stream,
62    /// Parameter values the server sent during startup (client_encoding,
63    /// server_version, TimeZone, …). Useful for diagnostics.
64    pub server_parameters: std::collections::HashMap<String, String>,
65    /// BackendKeyData cached for potential cancel requests.
66    pub backend_pid: Option<u32>,
67    pub backend_secret: Option<u32>,
68}
69
70impl BackendClient {
71    /// Connect, TLS-negotiate, authenticate, and drain server-initialisation
72    /// frames through `ReadyForQuery`. On success the client is idle and
73    /// ready to run SQL.
74    pub async fn connect(cfg: &BackendConfig) -> BackendResult<Self> {
75        tokio::time::timeout(cfg.connect_timeout, Self::connect_inner(cfg))
76            .await
77            .map_err(|_| BackendError::Io(std::io::Error::new(
78                std::io::ErrorKind::TimedOut,
79                format!("connect to {} exceeded {:?}", cfg.address(), cfg.connect_timeout),
80            )))?
81    }
82
83    async fn connect_inner(cfg: &BackendConfig) -> BackendResult<Self> {
84        let tcp = TcpStream::connect(cfg.address()).await?;
85        let mut stream =
86            negotiate(tcp, cfg.tls_mode, cfg.tls_config.clone(), &cfg.host).await?;
87
88        // Send StartupMessage (protocol 3.0).
89        let startup = build_startup(cfg);
90        stream.write_all(&startup).await?;
91
92        let mut server_parameters = std::collections::HashMap::new();
93        let mut backend_pid = None;
94        let mut backend_secret = None;
95        let mut buffer = BytesMut::with_capacity(4096);
96        let codec = ProtocolCodec::new();
97        let mut scram_state: Option<Scram> = None;
98
99        loop {
100            let msg = read_one(&mut stream, &mut buffer, &codec).await?;
101            match msg.msg_type {
102                MessageType::AuthRequest => {
103                    handle_auth(
104                        &mut stream,
105                        &msg,
106                        cfg,
107                        &mut scram_state,
108                    )
109                    .await?;
110                }
111                MessageType::ParameterStatus => {
112                    if let Some((k, v)) = parse_parameter_status(&msg.payload) {
113                        server_parameters.insert(k, v);
114                    }
115                }
116                MessageType::BackendKeyData => {
117                    if msg.payload.len() >= 8 {
118                        backend_pid = Some(u32::from_be_bytes(
119                            msg.payload[0..4].try_into().unwrap(),
120                        ));
121                        backend_secret = Some(u32::from_be_bytes(
122                            msg.payload[4..8].try_into().unwrap(),
123                        ));
124                    }
125                }
126                MessageType::ReadyForQuery => {
127                    return Ok(Self {
128                        stream,
129                        server_parameters,
130                        backend_pid,
131                        backend_secret,
132                    });
133                }
134                MessageType::ErrorResponse => {
135                    return Err(BackendError::BackendError(error_message(&msg.payload)));
136                }
137                MessageType::NoticeResponse => {
138                    // Ignore warnings during startup.
139                }
140                other => {
141                    return Err(BackendError::Protocol(format!(
142                        "unexpected message during startup: {:?}",
143                        other
144                    )));
145                }
146            }
147        }
148    }
149
150    /// Run a simple-query (SQL text) and collect every resulting row
151    /// into a `QueryResult`. For statements that don't return rows
152    /// (DDL, `SET`, etc.) `rows` will be empty and `command_tag`
153    /// carries the completion string.
154    pub async fn simple_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
155        self.run_query(sql).await
156    }
157
158    /// Like `simple_query` but substitutes `$1`, `$2`, … with
159    /// text-format literals before sending. We stick to simple-query
160    /// rather than the extended protocol to keep the surface narrow.
161    pub async fn query_with_params(
162        &mut self,
163        sql: &str,
164        params: &[ParamValue],
165    ) -> BackendResult<QueryResult> {
166        let substituted = interpolate_params(sql, params)?;
167        self.run_query(&substituted).await
168    }
169
170    /// Shorthand for a scalar lookup: runs `sql`, expects 1 column, 1 row.
171    pub async fn query_scalar(&mut self, sql: &str) -> BackendResult<TextValue> {
172        let res = self.simple_query(sql).await?;
173        if res.rows.len() != 1 {
174            return Err(BackendError::Protocol(format!(
175                "expected 1 row, got {}",
176                res.rows.len()
177            )));
178        }
179        if res.columns.len() != 1 {
180            return Err(BackendError::Protocol(format!(
181                "expected 1 column, got {}",
182                res.columns.len()
183            )));
184        }
185        Ok(res.rows.into_iter().next().unwrap().into_iter().next().unwrap())
186    }
187
188    /// Run a statement with no result set (DDL, SET, DO). Returns the
189    /// command tag (e.g. `"SET"`, `"CREATE TABLE"`).
190    pub async fn execute(&mut self, sql: &str) -> BackendResult<String> {
191        let res = self.simple_query(sql).await?;
192        Ok(res.command_tag)
193    }
194
195    async fn run_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
196        let t = self.stream_query_timeout();
197        tokio::time::timeout(t, Self::run_query_inner(&mut self.stream, sql))
198            .await
199            .map_err(|_| BackendError::Io(std::io::Error::new(
200                std::io::ErrorKind::TimedOut,
201                format!("query exceeded {:?}: {}", t, truncate(sql, 64)),
202            )))?
203    }
204
205    fn stream_query_timeout(&self) -> Duration {
206        // 30 seconds is a sane default for management queries; callers
207        // that need longer can wrap their own timeout around this call.
208        Duration::from_secs(30)
209    }
210
211    async fn run_query_inner(stream: &mut Stream, sql: &str) -> BackendResult<QueryResult> {
212        // Build and send a Query message (tag 'Q', payload = SQL + \0).
213        let mut payload = BytesMut::with_capacity(sql.len() + 1);
214        payload.extend_from_slice(sql.as_bytes());
215        payload.put_u8(0);
216        let frame = Message::new(MessageType::Query, payload).encode();
217        stream.write_all(&frame).await?;
218
219        let mut buffer = BytesMut::with_capacity(8192);
220        let codec = ProtocolCodec::new();
221        let mut columns: Vec<ColumnMeta> = Vec::new();
222        let mut rows: Vec<Vec<TextValue>> = Vec::new();
223        let mut command_tag = String::new();
224        let mut last_error: Option<String> = None;
225
226        loop {
227            let msg = read_one(stream, &mut buffer, &codec).await?;
228            match msg.msg_type {
229                // Both 'T' (RowDescription) and 'C' (CommandComplete)
230                // may appear; from_tag conflates T with… nothing. It
231                // DOES conflate D with Describe, but on a server→client
232                // frame the D tag here is always DataRow in practice,
233                // because we only arrive at run_query_inner after the
234                // startup handshake and the server never sends Describe.
235                MessageType::RowDescription => {
236                    columns = parse_row_description(&msg.payload);
237                }
238                MessageType::DataRow => {
239                    let row = parse_data_row(&msg.payload, columns.len())?;
240                    rows.push(row);
241                }
242                // PG tag 'C' = CommandComplete (server → client). The
243                // shared MessageType enum also has Close (client → server)
244                // under the same tag; again, the direction fixes the
245                // ambiguity at runtime.
246                MessageType::CommandComplete | MessageType::Close => {
247                    command_tag = parse_cstring(&msg.payload);
248                }
249                MessageType::EmptyQueryResponse => {
250                    command_tag = String::new();
251                }
252                MessageType::ErrorResponse => {
253                    last_error = Some(error_message(&msg.payload));
254                }
255                MessageType::NoticeResponse => {
256                    tracing::debug!(notice = %error_message(&msg.payload), "backend notice");
257                }
258                MessageType::ReadyForQuery => {
259                    if let Some(e) = last_error {
260                        return Err(BackendError::BackendError(e));
261                    }
262                    return Ok(QueryResult {
263                        columns,
264                        rows,
265                        command_tag,
266                    });
267                }
268                MessageType::ParameterStatus => {
269                    // Server may push parameter changes (e.g. after SET).
270                    // Ignore here; callers that care can call
271                    // simple_query("SHOW <param>") afterwards.
272                }
273                _other => {
274                    // Unknown message kind — keep draining until
275                    // ReadyForQuery. A well-behaved PG server won't
276                    // send anything strange outside of the above set.
277                }
278            }
279        }
280    }
281
282    /// Close the connection gracefully (send Terminate, close socket).
283    pub async fn close(mut self) {
284        let term = Message::new(MessageType::Terminate, BytesMut::new()).encode();
285        let _ = self.stream.write_all(&term).await;
286        let _ = self.stream.shutdown().await;
287    }
288
289    /// Report whether the underlying connection is over TLS.
290    pub fn is_tls(&self) -> bool {
291        self.stream.is_tls()
292    }
293}
294
295// ---------------------------------------------------------------------
296// Result shape
297// ---------------------------------------------------------------------
298
/// Metadata for a single result column, taken from a `RowDescription`.
#[derive(Debug, Clone)]
pub struct ColumnMeta {
    /// Column label reported by the server.
    pub name: String,
    /// OID of the column's data type (e.g. 23 for `int4`).
    pub type_oid: u32,
}
304
305#[derive(Debug, Clone)]
306pub struct QueryResult {
307    pub columns: Vec<ColumnMeta>,
308    pub rows: Vec<Vec<TextValue>>,
309    pub command_tag: String,
310}
311
312impl QueryResult {
313    /// Return the numeric suffix of a CommandComplete tag, if any.
314    /// For example, `"INSERT 0 5"` → `Some(5)`.
315    pub fn rows_affected(&self) -> Option<u64> {
316        self.command_tag
317            .split_whitespace()
318            .last()
319            .and_then(|s| s.parse().ok())
320    }
321}
322
323// ---------------------------------------------------------------------
324// Helpers
325// ---------------------------------------------------------------------
326
327fn build_startup(cfg: &BackendConfig) -> Vec<u8> {
328    let mut payload = BytesMut::with_capacity(128);
329    // Protocol version 3.0.
330    payload.put_u32(196608);
331    put_cstring(&mut payload, "user");
332    put_cstring(&mut payload, &cfg.user);
333    if let Some(db) = &cfg.database {
334        put_cstring(&mut payload, "database");
335        put_cstring(&mut payload, db);
336    }
337    put_cstring(&mut payload, "application_name");
338    put_cstring(
339        &mut payload,
340        cfg.application_name
341            .as_deref()
342            .unwrap_or("heliosdb-proxy"),
343    );
344    put_cstring(&mut payload, "client_encoding");
345    put_cstring(&mut payload, "UTF8");
346    payload.put_u8(0); // terminator
347
348    let mut framed = BytesMut::with_capacity(payload.len() + 4);
349    framed.put_u32((payload.len() + 4) as u32);
350    framed.extend_from_slice(&payload);
351    framed.to_vec()
352}
353
354fn put_cstring(buf: &mut BytesMut, s: &str) {
355    buf.extend_from_slice(s.as_bytes());
356    buf.put_u8(0);
357}
358
/// Decode the leading NUL-terminated string of `payload`.
///
/// Without a NUL, the whole payload is used. Invalid UTF-8 is replaced
/// lossily.
fn parse_cstring(payload: &[u8]) -> String {
    let text = payload.split(|&byte| byte == 0).next().unwrap_or(&[]);
    String::from_utf8_lossy(text).into_owned()
}
363
/// Decode a ParameterStatus body: `name\0value\0`.
///
/// Returns `None` when the name has no NUL terminator. A missing
/// terminator on the value is tolerated.
fn parse_parameter_status(payload: &[u8]) -> Option<(String, String)> {
    let mut fields = payload.splitn(3, |&byte| byte == 0);
    let name = fields.next()?;
    // A second field only exists if the name was NUL-terminated.
    let value = fields.next()?;
    Some((
        String::from_utf8_lossy(name).into_owned(),
        String::from_utf8_lossy(value).into_owned(),
    ))
}
372
373fn parse_row_description(payload: &[u8]) -> Vec<ColumnMeta> {
374    let mut p = BytesMut::from(payload);
375    if p.remaining() < 2 {
376        return Vec::new();
377    }
378    let n = p.get_u16() as usize;
379    let mut cols = Vec::with_capacity(n);
380    for _ in 0..n {
381        // cstring name
382        let end = match p.as_ref().iter().position(|&b| b == 0) {
383            Some(i) => i,
384            None => break,
385        };
386        let name = String::from_utf8_lossy(&p.as_ref()[..end]).into_owned();
387        p.advance(end + 1);
388        if p.remaining() < 18 {
389            break;
390        }
391        let _table_oid = p.get_u32();
392        let _column_number = p.get_u16();
393        let type_oid = p.get_u32();
394        let _type_len = p.get_i16();
395        let _type_mod = p.get_i32();
396        let _format_code = p.get_u16();
397        cols.push(ColumnMeta { name, type_oid });
398    }
399    cols
400}
401
402fn parse_data_row(payload: &[u8], column_count: usize) -> BackendResult<Vec<TextValue>> {
403    let mut p = BytesMut::from(payload);
404    if p.remaining() < 2 {
405        return Err(BackendError::Protocol("truncated DataRow".into()));
406    }
407    let n = p.get_u16() as usize;
408    let mut out = Vec::with_capacity(n);
409    for _ in 0..n {
410        if p.remaining() < 4 {
411            return Err(BackendError::Protocol("truncated DataRow field".into()));
412        }
413        let len = p.get_i32();
414        if len == -1 {
415            out.push(TextValue::Null);
416        } else {
417            let len = len as usize;
418            if p.remaining() < len {
419                return Err(BackendError::Protocol(
420                    "truncated DataRow value".into(),
421                ));
422            }
423            let bytes = p.split_to(len);
424            out.push(TextValue::Text(
425                String::from_utf8_lossy(&bytes).into_owned(),
426            ));
427        }
428    }
429    let _ = column_count;
430    Ok(out)
431}
432
/// Extract the human-readable `M` (message) field from an ErrorResponse or
/// NoticeResponse body.
///
/// The body is a sequence of `{ code: u8, value: cstring }` entries ended
/// by a zero code byte. If `M` appears more than once the last one wins.
/// Without any `M` field the result is `"<no message>"`.
fn error_message(payload: &[u8]) -> String {
    let mut message: Option<String> = None;
    let mut rest = payload;
    while let Some((&code, tail)) = rest.split_first() {
        if code == 0 {
            break;
        }
        let (value, after) = match tail.iter().position(|&byte| byte == 0) {
            Some(nul) => (&tail[..nul], &tail[nul + 1..]),
            // Unterminated final field: take what is left and stop.
            None => (tail, &[][..]),
        };
        if code == b'M' {
            message = Some(String::from_utf8_lossy(value).into_owned());
        }
        rest = after;
    }
    message.unwrap_or_else(|| "<no message>".to_string())
}
456
457async fn read_one(
458    stream: &mut Stream,
459    buffer: &mut BytesMut,
460    codec: &ProtocolCodec,
461) -> BackendResult<Message> {
462    loop {
463        if let Some(msg) = codec
464            .decode_message(buffer)
465            .map_err(|e| BackendError::Protocol(e.to_string()))?
466        {
467            return Ok(msg);
468        }
469        let mut tmp = vec![0u8; 4096];
470        let n = stream.read(&mut tmp).await?;
471        if n == 0 {
472            return Err(BackendError::Closed);
473        }
474        buffer.extend_from_slice(&tmp[..n]);
475    }
476}
477
478async fn handle_auth(
479    stream: &mut Stream,
480    msg: &Message,
481    cfg: &BackendConfig,
482    scram_state: &mut Option<Scram>,
483) -> BackendResult<()> {
484    if msg.payload.len() < 4 {
485        return Err(BackendError::Protocol(
486            "AuthRequest payload < 4 bytes".into(),
487        ));
488    }
489    let code =
490        u32::from_be_bytes([msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3]]);
491    match code {
492        0 => Ok(()), // AuthenticationOk
493        5 => {
494            // AuthenticationMD5Password + 4-byte salt.
495            if msg.payload.len() < 8 {
496                return Err(BackendError::Protocol("AuthenticationMD5 truncated".into()));
497            }
498            let salt: [u8; 4] = [
499                msg.payload[4],
500                msg.payload[5],
501                msg.payload[6],
502                msg.payload[7],
503            ];
504            let password = cfg.password.as_deref().ok_or_else(|| {
505                BackendError::Auth("server requested MD5 but no password configured".into())
506            })?;
507            let payload = md5_password_response(&cfg.user, password, &salt);
508            write_password_message(stream, &payload).await
509        }
510        3 => {
511            // AuthenticationCleartextPassword — plain password with null terminator.
512            let password = cfg.password.as_deref().ok_or_else(|| {
513                BackendError::Auth("server requested password but none configured".into())
514            })?;
515            let mut payload = Vec::with_capacity(password.len() + 1);
516            payload.extend_from_slice(password.as_bytes());
517            payload.push(0);
518            write_password_message(stream, &payload).await
519        }
520        10 => {
521            // AuthenticationSASL: list of mechanism cstrings, then 0.
522            let mechs = parse_sasl_mechanisms(&msg.payload[4..]);
523            if !mechs.iter().any(|m| m == "SCRAM-SHA-256") {
524                return Err(BackendError::Auth(format!(
525                    "no supported SASL mechanism; server offered {:?}",
526                    mechs
527                )));
528            }
529            let nonce = generate_nonce();
530            let (scram, first) = Scram::client_first(nonce);
531            *scram_state = Some(scram);
532            write_password_message(stream, &first.0).await
533        }
534        11 => {
535            // AuthenticationSASLContinue: server-first bytes.
536            let scram = scram_state.as_mut().ok_or_else(|| {
537                BackendError::Auth("SASLContinue before SASL start".into())
538            })?;
539            let password = cfg.password.as_deref().ok_or_else(|| {
540                BackendError::Auth("SCRAM requires a password".into())
541            })?;
542            let out = scram.client_final(&msg.payload[4..], password)?;
543            write_password_message(stream, &out.0).await
544        }
545        12 => {
546            // AuthenticationSASLFinal: v=<signature>
547            let scram = scram_state.as_ref().ok_or_else(|| {
548                BackendError::Auth("SASLFinal before SASL start".into())
549            })?;
550            scram.verify_server(&msg.payload[4..])
551        }
552        other => Err(BackendError::Auth(format!(
553            "unsupported authentication request code: {}",
554            other
555        ))),
556    }
557}
558
559async fn write_password_message(
560    stream: &mut Stream,
561    payload: &[u8],
562) -> BackendResult<()> {
563    let mut buf = BytesMut::with_capacity(payload.len() + 5);
564    buf.put_u8(b'p');
565    buf.put_u32((payload.len() + 4) as u32);
566    buf.extend_from_slice(payload);
567    stream.write_all(&buf).await?;
568    Ok(())
569}
570
/// Decode the mechanism list of an AuthenticationSASL message.
///
/// The list is a sequence of NUL-terminated names ended by an empty name.
/// Decoding stops at the first empty entry or at the end of the payload.
fn parse_sasl_mechanisms(payload: &[u8]) -> Vec<String> {
    payload
        .split(|&byte| byte == 0)
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
587
588fn generate_nonce() -> String {
589    use base64::Engine as _;
590    use rand::RngCore;
591    let mut bytes = [0u8; 18];
592    rand::thread_rng().fill_bytes(&mut bytes);
593    // URL-safe base64 without padding keeps it ASCII.
594    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
595}
596
597fn interpolate_params(sql: &str, params: &[ParamValue]) -> BackendResult<String> {
598    // Walk the SQL, replacing $N tokens outside of string literals with
599    // the encoded parameter. This is a deliberately simple interpolator
600    // for internal management queries; it does NOT try to be a full
601    // PG parser.
602    let mut out = String::with_capacity(sql.len());
603    let bytes = sql.as_bytes();
604    let mut i = 0;
605    let mut in_string = false;
606    let mut quote_char = 0u8;
607    while i < bytes.len() {
608        let b = bytes[i];
609        if in_string {
610            out.push(b as char);
611            if b == quote_char {
612                // PG doubles the quote to escape; peek ahead.
613                if i + 1 < bytes.len() && bytes[i + 1] == quote_char {
614                    out.push(quote_char as char);
615                    i += 2;
616                    continue;
617                }
618                in_string = false;
619            }
620            i += 1;
621            continue;
622        }
623        if b == b'\'' || b == b'"' {
624            in_string = true;
625            quote_char = b;
626            out.push(b as char);
627            i += 1;
628            continue;
629        }
630        if b == b'$' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
631            let mut j = i + 1;
632            while j < bytes.len() && bytes[j].is_ascii_digit() {
633                j += 1;
634            }
635            let idx: usize = std::str::from_utf8(&bytes[i + 1..j])
636                .unwrap()
637                .parse()
638                .map_err(|_| {
639                    BackendError::Protocol(format!(
640                        "invalid parameter reference at byte {}",
641                        i
642                    ))
643                })?;
644            if idx == 0 || idx > params.len() {
645                return Err(BackendError::Protocol(format!(
646                    "parameter ${} out of range (have {})",
647                    idx,
648                    params.len()
649                )));
650            }
651            out.push_str(&encode_literal(&params[idx - 1]));
652            i = j;
653            continue;
654        }
655        out.push(b as char);
656        i += 1;
657    }
658    Ok(out)
659}
660
/// Return at most the first `n` characters of `s` (by `char`, not byte),
/// so the cut never lands inside a multi-byte UTF-8 sequence.
fn truncate(s: &str, n: usize) -> &str {
    let cut = s
        .char_indices()
        .map(|(byte_idx, _)| byte_idx)
        .nth(n)
        .unwrap_or(s.len());
    &s[..cut]
}
667
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::types::ParamValue;

    /// Encode `s` as a PostgreSQL cstring (bytes followed by NUL).
    fn cstr(s: &str) -> Vec<u8> {
        let mut bytes = s.as_bytes().to_vec();
        bytes.push(0);
        bytes
    }

    /// True when `needle` occurs as a contiguous run inside `haystack`.
    fn contains_seq(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    #[test]
    fn test_build_startup_has_user_and_protocol_version() {
        let cfg = BackendConfig {
            host: "localhost".into(),
            port: 5432,
            user: "alice".into(),
            password: None,
            database: Some("app".into()),
            application_name: None,
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_secs(5),
            query_timeout: Duration::from_secs(5),
            tls_config: crate::backend::tls::default_client_config(),
        };
        let startup = build_startup(&cfg);
        // Bytes 0..4 hold the length prefix; 4..8 the protocol version (3 << 16).
        let version = u32::from_be_bytes(startup[4..8].try_into().unwrap());
        assert_eq!(version, 3 << 16);
        assert!(contains_seq(&startup, b"user\0"));
        assert!(contains_seq(&startup, b"database\0a"));
    }

    #[test]
    fn test_interpolate_params_basic() {
        let out = interpolate_params(
            "SELECT * FROM t WHERE id = $1 AND name = $2",
            &[ParamValue::Int(42), ParamValue::Text("alice".into())],
        )
        .unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE id = 42 AND name = 'alice'");
    }

    #[test]
    fn test_interpolate_params_escapes_quotes() {
        let params = [ParamValue::Text("o'brien".into())];
        let out = interpolate_params("SELECT * FROM t WHERE name = $1", &params).unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE name = 'o''brien'");
    }

    #[test]
    fn test_interpolate_params_leaves_dollar_in_string_alone() {
        let out =
            interpolate_params("SELECT '$1' AS lit, $1 AS val", &[ParamValue::Int(1)]).unwrap();
        assert_eq!(out, "SELECT '$1' AS lit, 1 AS val");
    }

    #[test]
    fn test_interpolate_params_out_of_range() {
        let result = interpolate_params("SELECT $2", &[ParamValue::Int(1)]);
        assert!(matches!(result, Err(BackendError::Protocol(_))));
    }

    #[test]
    fn test_parse_row_description_shape() {
        // One int4 column named "x". Fields: table OID 0, attnum 0,
        // type OID 23, typlen 4, typmod -1, text format.
        let mut payload = BytesMut::new();
        payload.put_u16(1);
        payload.put_slice(&cstr("x"));
        payload.put_u32(0);
        payload.put_u16(0);
        payload.put_u32(23);
        payload.put_i16(4);
        payload.put_i32(-1);
        payload.put_u16(0);
        let columns = parse_row_description(&payload);
        assert_eq!(columns.len(), 1);
        assert_eq!(columns[0].name, "x");
        assert_eq!(columns[0].type_oid, 23);
    }

    #[test]
    fn test_parse_data_row_with_null() {
        // Two fields: "a" (length 1) and NULL (length -1).
        let mut payload = BytesMut::new();
        payload.put_u16(2);
        payload.put_i32(1);
        payload.put_slice(b"a");
        payload.put_i32(-1);
        let row = parse_data_row(&payload, 2).unwrap();
        assert_eq!(row, vec![TextValue::Text("a".into()), TextValue::Null]);
    }

    #[test]
    fn test_error_message_extracts_m_field() {
        let mut payload = Vec::new();
        for (code, text) in [
            (b'S', "ERROR"),
            (b'C', "28P01"),
            (b'M', "password authentication failed"),
        ] {
            payload.push(code);
            payload.extend_from_slice(&cstr(text));
        }
        payload.push(0);
        assert_eq!(error_message(&payload), "password authentication failed");
    }

    #[test]
    fn test_parse_parameter_status() {
        let payload = [cstr("client_encoding"), cstr("UTF8")].concat();
        assert_eq!(
            parse_parameter_status(&payload),
            Some(("client_encoding".to_string(), "UTF8".to_string()))
        );
    }

    #[test]
    fn test_parse_sasl_mechanisms() {
        let mut payload = [cstr("SCRAM-SHA-256"), cstr("SCRAM-SHA-256-PLUS")].concat();
        payload.push(0);
        assert_eq!(
            parse_sasl_mechanisms(&payload),
            ["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]
        );
    }

    #[test]
    fn test_generate_nonce_is_url_safe() {
        let nonce = generate_nonce();
        let url_safe = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_');
        assert!(nonce.chars().all(url_safe));
        assert!(nonce.len() >= 18);
    }

    #[test]
    fn test_query_result_rows_affected() {
        let with_tag = |tag: &str| QueryResult {
            columns: Vec::new(),
            rows: Vec::new(),
            command_tag: tag.to_string(),
        };
        assert_eq!(with_tag("INSERT 0 5").rows_affected(), Some(5));
        assert_eq!(with_tag("SET").rows_affected(), None);
    }
}