1use super::auth::{md5_password_response, Scram};
14use super::error::{BackendError, BackendResult};
15use super::stream::Stream;
16use super::tls::{negotiate, TlsMode};
17use super::types::{encode_literal, ParamValue, TextValue};
18use crate::protocol::{Message, MessageType, ProtocolCodec};
19use bytes::{Buf, BufMut, BytesMut};
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::TcpStream;
24
/// Connection settings for one backend server.
#[derive(Debug, Clone)]
pub struct BackendConfig {
    /// Host name or IP address; also passed to `negotiate` during TLS setup.
    pub host: String,
    /// TCP port of the backend.
    pub port: u16,
    /// User name sent in the startup packet and used for MD5 authentication.
    pub user: String,
    /// Password used when the server requests cleartext, MD5 or SCRAM
    /// authentication; those requests fail with an auth error when `None`.
    pub password: Option<String>,
    /// Database name; the `database` startup parameter is omitted when `None`.
    pub database: Option<String>,
    /// `application_name` startup parameter; `"heliosdb-proxy"` is sent when `None`.
    pub application_name: Option<String>,
    /// TLS mode handed to `negotiate`.
    pub tls_mode: TlsMode,
    /// Upper bound applied by `BackendClient::connect` to the whole connect sequence.
    pub connect_timeout: Duration,
    /// NOTE(review): not read anywhere in the visible part of this file;
    /// `BackendClient` uses a fixed per-query limit instead. Confirm intent.
    pub query_timeout: Duration,
    /// rustls client configuration handed to `negotiate`.
    pub tls_config: Arc<rustls::ClientConfig>,
}
52
53impl BackendConfig {
54 pub fn address(&self) -> String {
55 format!("{}:{}", self.host, self.port)
56 }
57}
58
/// An established, authenticated connection to a backend server.
pub struct BackendClient {
    /// Transport returned by `negotiate` during connect.
    stream: Stream,
    /// Name/value pairs collected from ParameterStatus messages during startup.
    pub server_parameters: std::collections::HashMap<String, String>,
    /// First big-endian u32 of the BackendKeyData payload; `None` if no such
    /// message (or one shorter than 8 bytes) was received during startup.
    pub backend_pid: Option<u32>,
    /// Second big-endian u32 of the BackendKeyData payload; `None` under the
    /// same conditions as `backend_pid`.
    pub backend_secret: Option<u32>,
}
69
70impl BackendClient {
71 pub async fn connect(cfg: &BackendConfig) -> BackendResult<Self> {
75 tokio::time::timeout(cfg.connect_timeout, Self::connect_inner(cfg))
76 .await
77 .map_err(|_| BackendError::Io(std::io::Error::new(
78 std::io::ErrorKind::TimedOut,
79 format!("connect to {} exceeded {:?}", cfg.address(), cfg.connect_timeout),
80 )))?
81 }
82
83 async fn connect_inner(cfg: &BackendConfig) -> BackendResult<Self> {
84 let tcp = TcpStream::connect(cfg.address()).await?;
85 let mut stream =
86 negotiate(tcp, cfg.tls_mode, cfg.tls_config.clone(), &cfg.host).await?;
87
88 let startup = build_startup(cfg);
90 stream.write_all(&startup).await?;
91
92 let mut server_parameters = std::collections::HashMap::new();
93 let mut backend_pid = None;
94 let mut backend_secret = None;
95 let mut buffer = BytesMut::with_capacity(4096);
96 let codec = ProtocolCodec::new();
97 let mut scram_state: Option<Scram> = None;
98
99 loop {
100 let msg = read_one(&mut stream, &mut buffer, &codec).await?;
101 match msg.msg_type {
102 MessageType::AuthRequest => {
103 handle_auth(
104 &mut stream,
105 &msg,
106 cfg,
107 &mut scram_state,
108 )
109 .await?;
110 }
111 MessageType::ParameterStatus => {
112 if let Some((k, v)) = parse_parameter_status(&msg.payload) {
113 server_parameters.insert(k, v);
114 }
115 }
116 MessageType::BackendKeyData => {
117 if msg.payload.len() >= 8 {
118 backend_pid = Some(u32::from_be_bytes(
119 msg.payload[0..4].try_into().unwrap(),
120 ));
121 backend_secret = Some(u32::from_be_bytes(
122 msg.payload[4..8].try_into().unwrap(),
123 ));
124 }
125 }
126 MessageType::ReadyForQuery => {
127 return Ok(Self {
128 stream,
129 server_parameters,
130 backend_pid,
131 backend_secret,
132 });
133 }
134 MessageType::ErrorResponse => {
135 return Err(BackendError::BackendError(error_message(&msg.payload)));
136 }
137 MessageType::NoticeResponse => {
138 }
140 other => {
141 return Err(BackendError::Protocol(format!(
142 "unexpected message during startup: {:?}",
143 other
144 )));
145 }
146 }
147 }
148 }
149
    /// Runs `sql` through `run_query` and returns the collected result.
    ///
    /// # Errors
    /// Propagates whatever `run_query` returns.
    pub async fn simple_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
        self.run_query(sql).await
    }
157
    /// Substitutes `params` into `sql` on the client side with
    /// `interpolate_params`, then runs the resulting text through `run_query`.
    ///
    /// # Errors
    /// Returns the interpolation error (for example an out-of-range `$n`)
    /// before anything is sent; otherwise propagates `run_query`'s result.
    pub async fn query_with_params(
        &mut self,
        sql: &str,
        params: &[ParamValue],
    ) -> BackendResult<QueryResult> {
        let substituted = interpolate_params(sql, params)?;
        self.run_query(&substituted).await
    }
169
170 pub async fn query_scalar(&mut self, sql: &str) -> BackendResult<TextValue> {
172 let res = self.simple_query(sql).await?;
173 if res.rows.len() != 1 {
174 return Err(BackendError::Protocol(format!(
175 "expected 1 row, got {}",
176 res.rows.len()
177 )));
178 }
179 if res.columns.len() != 1 {
180 return Err(BackendError::Protocol(format!(
181 "expected 1 column, got {}",
182 res.columns.len()
183 )));
184 }
185 Ok(res.rows.into_iter().next().unwrap().into_iter().next().unwrap())
186 }
187
188 pub async fn execute(&mut self, sql: &str) -> BackendResult<String> {
191 let res = self.simple_query(sql).await?;
192 Ok(res.command_tag)
193 }
194
195 pub async fn copy_in(&mut self, copy_sql: &str, data: &[u8]) -> BackendResult<u64> {
204 let t = Duration::from_secs(600);
207 tokio::time::timeout(t, Self::copy_in_inner(&mut self.stream, copy_sql, data))
208 .await
209 .map_err(|_| {
210 BackendError::Io(std::io::Error::new(
211 std::io::ErrorKind::TimedOut,
212 format!("COPY exceeded {:?}", t),
213 ))
214 })?
215 }
216
217 async fn copy_in_inner(stream: &mut Stream, copy_sql: &str, data: &[u8]) -> BackendResult<u64> {
218 let mut payload = BytesMut::with_capacity(copy_sql.len() + 1);
220 payload.extend_from_slice(copy_sql.as_bytes());
221 payload.put_u8(0);
222 stream
223 .write_all(&Message::new(MessageType::Query, payload).encode())
224 .await?;
225
226 let mut buffer = BytesMut::with_capacity(8192);
227 let codec = ProtocolCodec::new();
228
229 loop {
234 let msg = read_one(stream, &mut buffer, &codec).await?;
235 match msg.msg_type {
236 MessageType::Unknown(b'G') => break,
237 MessageType::ErrorResponse => {
238 let e = error_message(&msg.payload);
239 Self::drain_to_ready(stream, &mut buffer, &codec).await?;
240 return Err(BackendError::BackendError(e));
241 }
242 MessageType::ReadyForQuery => {
243 return Err(BackendError::Protocol(
244 "COPY: ReadyForQuery before CopyInResponse".into(),
245 ));
246 }
247 _ => {} }
249 }
250
251 const CHUNK: usize = 64 * 1024;
253 let mut off = 0;
254 while off < data.len() {
255 let end = (off + CHUNK).min(data.len());
256 let mut p = BytesMut::with_capacity(end - off);
257 p.extend_from_slice(&data[off..end]);
258 stream
259 .write_all(&Message::new(MessageType::CopyData, p).encode())
260 .await?;
261 off = end;
262 }
263 stream
264 .write_all(&Message::new(MessageType::CopyDone, BytesMut::new()).encode())
265 .await?;
266
267 let mut tag = String::new();
269 let mut last_error = None;
270 loop {
271 let msg = read_one(stream, &mut buffer, &codec).await?;
272 match msg.msg_type {
273 MessageType::CommandComplete | MessageType::Close => {
274 tag = parse_cstring(&msg.payload);
275 }
276 MessageType::ErrorResponse => {
277 last_error = Some(error_message(&msg.payload));
278 }
279 MessageType::ReadyForQuery => {
280 if let Some(e) = last_error {
281 return Err(BackendError::BackendError(e));
282 }
283 let n = tag
285 .rsplit(' ')
286 .next()
287 .and_then(|s| s.parse::<u64>().ok())
288 .unwrap_or(0);
289 return Ok(n);
290 }
291 _ => {}
292 }
293 }
294 }
295
296 async fn drain_to_ready(
297 stream: &mut Stream,
298 buffer: &mut BytesMut,
299 codec: &ProtocolCodec,
300 ) -> BackendResult<()> {
301 loop {
302 if read_one(stream, buffer, codec).await?.msg_type == MessageType::ReadyForQuery {
303 return Ok(());
304 }
305 }
306 }
307
308 async fn run_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
309 let t = self.stream_query_timeout();
310 tokio::time::timeout(t, Self::run_query_inner(&mut self.stream, sql))
311 .await
312 .map_err(|_| BackendError::Io(std::io::Error::new(
313 std::io::ErrorKind::TimedOut,
314 format!("query exceeded {:?}: {}", t, truncate(sql, 64)),
315 )))?
316 }
317
    /// Returns the time limit `run_query` applies to each query.
    ///
    /// The value is a fixed 30 seconds; nothing on `self` is consulted.
    fn stream_query_timeout(&self) -> Duration {
        Duration::from_secs(30)
    }
323
324 async fn run_query_inner(stream: &mut Stream, sql: &str) -> BackendResult<QueryResult> {
325 let mut payload = BytesMut::with_capacity(sql.len() + 1);
327 payload.extend_from_slice(sql.as_bytes());
328 payload.put_u8(0);
329 let frame = Message::new(MessageType::Query, payload).encode();
330 stream.write_all(&frame).await?;
331
332 let mut buffer = BytesMut::with_capacity(8192);
333 let codec = ProtocolCodec::new();
334 let mut columns: Vec<ColumnMeta> = Vec::new();
335 let mut rows: Vec<Vec<TextValue>> = Vec::new();
336 let mut command_tag = String::new();
337 let mut last_error: Option<String> = None;
338
339 loop {
340 let msg = read_one(stream, &mut buffer, &codec).await?;
341 match msg.msg_type {
342 MessageType::RowDescription => {
349 columns = parse_row_description(&msg.payload);
350 }
351 MessageType::DataRow => {
352 let row = parse_data_row(&msg.payload, columns.len())?;
353 rows.push(row);
354 }
355 MessageType::CommandComplete | MessageType::Close => {
360 command_tag = parse_cstring(&msg.payload);
361 }
362 MessageType::EmptyQueryResponse => {
363 command_tag = String::new();
364 }
365 MessageType::ErrorResponse => {
366 last_error = Some(error_message(&msg.payload));
367 }
368 MessageType::NoticeResponse => {
369 tracing::debug!(notice = %error_message(&msg.payload), "backend notice");
370 }
371 MessageType::ReadyForQuery => {
372 if let Some(e) = last_error {
373 return Err(BackendError::BackendError(e));
374 }
375 return Ok(QueryResult {
376 columns,
377 rows,
378 command_tag,
379 });
380 }
381 MessageType::ParameterStatus => {
382 }
386 _other => {
387 }
391 }
392 }
393 }
394
395 pub async fn close(mut self) {
397 let term = Message::new(MessageType::Terminate, BytesMut::new()).encode();
398 let _ = self.stream.write_all(&term).await;
399 let _ = self.stream.shutdown().await;
400 }
401
    /// Reports whether the connection is using TLS, as answered by the
    /// underlying stream's `is_tls`.
    pub fn is_tls(&self) -> bool {
        self.stream.is_tls()
    }
406}
407
/// Name and type of one result column, as read from a RowDescription message.
#[derive(Debug, Clone)]
pub struct ColumnMeta {
    /// Column name (lossily decoded as UTF-8).
    pub name: String,
    /// Type OID reported for the column.
    pub type_oid: u32,
}
417
/// Everything collected from one query response.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Columns from the most recent RowDescription; empty if none was sent.
    pub columns: Vec<ColumnMeta>,
    /// One entry per DataRow, in arrival order.
    pub rows: Vec<Vec<TextValue>>,
    /// Last command tag received; empty for an empty query.
    pub command_tag: String,
}
424
425impl QueryResult {
426 pub fn rows_affected(&self) -> Option<u64> {
429 self.command_tag
430 .split_whitespace()
431 .last()
432 .and_then(|s| s.parse().ok())
433 }
434}
435
436fn build_startup(cfg: &BackendConfig) -> Vec<u8> {
441 let mut payload = BytesMut::with_capacity(128);
442 payload.put_u32(196608);
444 put_cstring(&mut payload, "user");
445 put_cstring(&mut payload, &cfg.user);
446 if let Some(db) = &cfg.database {
447 put_cstring(&mut payload, "database");
448 put_cstring(&mut payload, db);
449 }
450 put_cstring(&mut payload, "application_name");
451 put_cstring(
452 &mut payload,
453 cfg.application_name
454 .as_deref()
455 .unwrap_or("heliosdb-proxy"),
456 );
457 put_cstring(&mut payload, "client_encoding");
458 put_cstring(&mut payload, "UTF8");
459 payload.put_u8(0); let mut framed = BytesMut::with_capacity(payload.len() + 4);
462 framed.put_u32((payload.len() + 4) as u32);
463 framed.extend_from_slice(&payload);
464 framed.to_vec()
465}
466
467fn put_cstring(buf: &mut BytesMut, s: &str) {
468 buf.extend_from_slice(s.as_bytes());
469 buf.put_u8(0);
470}
471
/// Decodes the bytes of `payload` up to (not including) the first NUL, or
/// the whole payload when there is none, replacing invalid UTF-8 lossily.
fn parse_cstring(payload: &[u8]) -> String {
    let text = payload
        .split(|&byte| byte == 0)
        .next()
        .unwrap_or(payload);
    String::from_utf8_lossy(text).into_owned()
}
476
/// Splits a ParameterStatus payload into its `(name, value)` pair.
///
/// The name ends at the first NUL; the value runs from there to the next
/// NUL or to the end of the payload. Both are decoded lossily. Returns
/// `None` when the payload contains no NUL at all.
fn parse_parameter_status(payload: &[u8]) -> Option<(String, String)> {
    let name_end = payload.iter().position(|&byte| byte == 0)?;
    let (name, tail) = payload.split_at(name_end);
    // `tail` starts with the NUL that terminated the name.
    let value = tail[1..]
        .split(|&byte| byte == 0)
        .next()
        .unwrap_or(&[]);
    Some((
        String::from_utf8_lossy(name).into_owned(),
        String::from_utf8_lossy(value).into_owned(),
    ))
}
485
486fn parse_row_description(payload: &[u8]) -> Vec<ColumnMeta> {
487 let mut p = BytesMut::from(payload);
488 if p.remaining() < 2 {
489 return Vec::new();
490 }
491 let n = p.get_u16() as usize;
492 let mut cols = Vec::with_capacity(n);
493 for _ in 0..n {
494 let end = match p.as_ref().iter().position(|&b| b == 0) {
496 Some(i) => i,
497 None => break,
498 };
499 let name = String::from_utf8_lossy(&p.as_ref()[..end]).into_owned();
500 p.advance(end + 1);
501 if p.remaining() < 18 {
502 break;
503 }
504 let _table_oid = p.get_u32();
505 let _column_number = p.get_u16();
506 let type_oid = p.get_u32();
507 let _type_len = p.get_i16();
508 let _type_mod = p.get_i32();
509 let _format_code = p.get_u16();
510 cols.push(ColumnMeta { name, type_oid });
511 }
512 cols
513}
514
515fn parse_data_row(payload: &[u8], column_count: usize) -> BackendResult<Vec<TextValue>> {
516 let mut p = BytesMut::from(payload);
517 if p.remaining() < 2 {
518 return Err(BackendError::Protocol("truncated DataRow".into()));
519 }
520 let n = p.get_u16() as usize;
521 let mut out = Vec::with_capacity(n);
522 for _ in 0..n {
523 if p.remaining() < 4 {
524 return Err(BackendError::Protocol("truncated DataRow field".into()));
525 }
526 let len = p.get_i32();
527 if len == -1 {
528 out.push(TextValue::Null);
529 } else {
530 let len = len as usize;
531 if p.remaining() < len {
532 return Err(BackendError::Protocol(
533 "truncated DataRow value".into(),
534 ));
535 }
536 let bytes = p.split_to(len);
537 out.push(TextValue::Text(
538 String::from_utf8_lossy(&bytes).into_owned(),
539 ));
540 }
541 }
542 let _ = column_count;
543 Ok(out)
544}
545
/// Extracts the `M` field from an ErrorResponse/NoticeResponse payload.
///
/// The payload is a sequence of fields, each a one-byte code followed by a
/// NUL-terminated value, ended by a zero code byte. If several `M` fields
/// are present the last one wins; `"<no message>"` is returned when there
/// is none. Values are decoded lossily, and an unterminated final value
/// runs to the end of the payload.
fn error_message(payload: &[u8]) -> String {
    let mut rest = payload;
    let mut message: Option<String> = None;

    while let Some((&code, after_code)) = rest.split_first() {
        if code == 0 {
            break;
        }
        let value_len = after_code
            .iter()
            .position(|&byte| byte == 0)
            .unwrap_or(after_code.len());
        if code == b'M' {
            message = Some(String::from_utf8_lossy(&after_code[..value_len]).into_owned());
        }
        // Skip the value and its terminator; running off the end stops the loop.
        rest = after_code.get(value_len + 1..).unwrap_or(&[]);
    }

    message.unwrap_or_else(|| "<no message>".to_string())
}
569
570async fn read_one(
571 stream: &mut Stream,
572 buffer: &mut BytesMut,
573 codec: &ProtocolCodec,
574) -> BackendResult<Message> {
575 loop {
576 if let Some(mut msg) = codec
577 .decode_message(buffer)
578 .map_err(|e| BackendError::Protocol(e.to_string()))?
579 {
580 msg.msg_type = match msg.msg_type {
587 MessageType::Sync => MessageType::ParameterStatus,
588 MessageType::Describe => MessageType::DataRow,
589 MessageType::Execute => MessageType::ErrorResponse,
590 MessageType::Close => MessageType::CommandComplete,
591 other => other,
592 };
593 return Ok(msg);
594 }
595 let mut tmp = vec![0u8; 4096];
596 let n = stream.read(&mut tmp).await?;
597 if n == 0 {
598 return Err(BackendError::Closed);
599 }
600 buffer.extend_from_slice(&tmp[..n]);
601 }
602}
603
604async fn handle_auth(
605 stream: &mut Stream,
606 msg: &Message,
607 cfg: &BackendConfig,
608 scram_state: &mut Option<Scram>,
609) -> BackendResult<()> {
610 if msg.payload.len() < 4 {
611 return Err(BackendError::Protocol(
612 "AuthRequest payload < 4 bytes".into(),
613 ));
614 }
615 let code =
616 u32::from_be_bytes([msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3]]);
617 match code {
618 0 => Ok(()), 5 => {
620 if msg.payload.len() < 8 {
622 return Err(BackendError::Protocol("AuthenticationMD5 truncated".into()));
623 }
624 let salt: [u8; 4] = [
625 msg.payload[4],
626 msg.payload[5],
627 msg.payload[6],
628 msg.payload[7],
629 ];
630 let password = cfg.password.as_deref().ok_or_else(|| {
631 BackendError::Auth("server requested MD5 but no password configured".into())
632 })?;
633 let payload = md5_password_response(&cfg.user, password, &salt);
634 write_password_message(stream, &payload).await
635 }
636 3 => {
637 let password = cfg.password.as_deref().ok_or_else(|| {
639 BackendError::Auth("server requested password but none configured".into())
640 })?;
641 let mut payload = Vec::with_capacity(password.len() + 1);
642 payload.extend_from_slice(password.as_bytes());
643 payload.push(0);
644 write_password_message(stream, &payload).await
645 }
646 10 => {
647 let mechs = parse_sasl_mechanisms(&msg.payload[4..]);
649 if !mechs.iter().any(|m| m == "SCRAM-SHA-256") {
650 return Err(BackendError::Auth(format!(
651 "no supported SASL mechanism; server offered {:?}",
652 mechs
653 )));
654 }
655 let nonce = generate_nonce();
656 let (scram, first) = Scram::client_first(nonce);
657 *scram_state = Some(scram);
658 write_password_message(stream, &first.0).await
659 }
660 11 => {
661 let scram = scram_state.as_mut().ok_or_else(|| {
663 BackendError::Auth("SASLContinue before SASL start".into())
664 })?;
665 let password = cfg.password.as_deref().ok_or_else(|| {
666 BackendError::Auth("SCRAM requires a password".into())
667 })?;
668 let out = scram.client_final(&msg.payload[4..], password)?;
669 write_password_message(stream, &out.0).await
670 }
671 12 => {
672 let scram = scram_state.as_ref().ok_or_else(|| {
674 BackendError::Auth("SASLFinal before SASL start".into())
675 })?;
676 scram.verify_server(&msg.payload[4..])
677 }
678 other => Err(BackendError::Auth(format!(
679 "unsupported authentication request code: {}",
680 other
681 ))),
682 }
683}
684
685async fn write_password_message(
686 stream: &mut Stream,
687 payload: &[u8],
688) -> BackendResult<()> {
689 let mut buf = BytesMut::with_capacity(payload.len() + 5);
690 buf.put_u8(b'p');
691 buf.put_u32((payload.len() + 4) as u32);
692 buf.extend_from_slice(payload);
693 stream.write_all(&buf).await?;
694 Ok(())
695}
696
/// Parses the NUL-separated mechanism list of a SASL authentication request.
///
/// Names are collected in order until an empty name (the list terminator)
/// or the end of the payload; a final name without a terminating NUL is
/// still returned. Names are decoded lossily.
fn parse_sasl_mechanisms(payload: &[u8]) -> Vec<String> {
    payload
        .split(|&byte| byte == 0)
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
713
714fn generate_nonce() -> String {
715 use base64::Engine as _;
716 use rand::RngCore;
717 let mut bytes = [0u8; 18];
718 rand::thread_rng().fill_bytes(&mut bytes);
719 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
721}
722
723fn interpolate_params(sql: &str, params: &[ParamValue]) -> BackendResult<String> {
724 let mut out = String::with_capacity(sql.len());
729 let bytes = sql.as_bytes();
730 let mut i = 0;
731 let mut in_string = false;
732 let mut quote_char = 0u8;
733 while i < bytes.len() {
734 let b = bytes[i];
735 if in_string {
736 out.push(b as char);
737 if b == quote_char {
738 if i + 1 < bytes.len() && bytes[i + 1] == quote_char {
740 out.push(quote_char as char);
741 i += 2;
742 continue;
743 }
744 in_string = false;
745 }
746 i += 1;
747 continue;
748 }
749 if b == b'\'' || b == b'"' {
750 in_string = true;
751 quote_char = b;
752 out.push(b as char);
753 i += 1;
754 continue;
755 }
756 if b == b'$' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
757 let mut j = i + 1;
758 while j < bytes.len() && bytes[j].is_ascii_digit() {
759 j += 1;
760 }
761 let idx: usize = std::str::from_utf8(&bytes[i + 1..j])
762 .unwrap()
763 .parse()
764 .map_err(|_| {
765 BackendError::Protocol(format!(
766 "invalid parameter reference at byte {}",
767 i
768 ))
769 })?;
770 if idx == 0 || idx > params.len() {
771 return Err(BackendError::Protocol(format!(
772 "parameter ${} out of range (have {})",
773 idx,
774 params.len()
775 )));
776 }
777 out.push_str(&encode_literal(¶ms[idx - 1]));
778 i = j;
779 continue;
780 }
781 out.push(b as char);
782 i += 1;
783 }
784 Ok(out)
785}
786
/// Returns the prefix of `s` holding at most `n` characters, cut on a
/// character boundary; `s` itself when it has `n` characters or fewer.
fn truncate(s: &str, n: usize) -> &str {
    s.char_indices()
        .nth(n)
        .map_or(s, |(cut, _)| &s[..cut])
}
793
/// Unit tests for the pure helpers in this file; none of them opens a
/// connection.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::types::ParamValue;

    /// The startup packet carries the protocol version right after the
    /// 4-byte length, plus the `user` and `database` parameters.
    #[test]
    fn test_build_startup_has_user_and_protocol_version() {
        let cfg = BackendConfig {
            host: "localhost".into(),
            port: 5432,
            user: "alice".into(),
            password: None,
            database: Some("app".into()),
            application_name: None,
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_secs(5),
            query_timeout: Duration::from_secs(5),
            tls_config: crate::backend::tls::default_client_config(),
        };
        let bytes = build_startup(&cfg);
        assert_eq!(
            u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]),
            196608
        );
        assert!(bytes
            .windows(5)
            .any(|w| w == b"user\0"));
        assert!(bytes
            .windows(10)
            .any(|w| w == b"database\0a"));
    }

    /// Placeholders are replaced by the encoded literal of each parameter.
    #[test]
    fn test_interpolate_params_basic() {
        let params = vec![
            ParamValue::Int(42),
            ParamValue::Text("alice".into()),
        ];
        let sql = "SELECT * FROM t WHERE id = $1 AND name = $2";
        let out = interpolate_params(sql, &params).unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE id = 42 AND name = 'alice'");
    }

    /// Single quotes inside a text parameter are doubled.
    #[test]
    fn test_interpolate_params_escapes_quotes() {
        let params = vec![ParamValue::Text("o'brien".into())];
        let out =
            interpolate_params("SELECT * FROM t WHERE name = $1", &params).unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE name = 'o''brien'");
    }

    /// A `$1` inside a quoted string is not a placeholder.
    #[test]
    fn test_interpolate_params_leaves_dollar_in_string_alone() {
        let params = vec![ParamValue::Int(1)];
        let sql = "SELECT '$1' AS lit, $1 AS val";
        let out = interpolate_params(sql, &params).unwrap();
        assert_eq!(out, "SELECT '$1' AS lit, 1 AS val");
    }

    /// Referencing a parameter that was not supplied is a protocol error.
    #[test]
    fn test_interpolate_params_out_of_range() {
        let params = vec![ParamValue::Int(1)];
        let err = interpolate_params("SELECT $2", &params).unwrap_err();
        assert!(matches!(err, BackendError::Protocol(_)));
    }

    /// One column: name, then the 18 bytes of fixed fields.
    #[test]
    fn test_parse_row_description_shape() {
        let mut p = BytesMut::new();
        p.put_u16(1);
        p.extend_from_slice(b"x");
        p.put_u8(0);
        p.put_u32(0);
        p.put_u16(0);
        p.put_u32(23);
        p.put_i16(4);
        p.put_i32(-1);
        p.put_u16(0);
        let cols = parse_row_description(&p);
        assert_eq!(cols.len(), 1);
        assert_eq!(cols[0].name, "x");
        assert_eq!(cols[0].type_oid, 23);
    }

    /// A field length of -1 decodes as NULL.
    #[test]
    fn test_parse_data_row_with_null() {
        let mut p = BytesMut::new();
        p.put_u16(2);
        p.put_i32(1);
        p.extend_from_slice(b"a");
        p.put_i32(-1);
        let row = parse_data_row(&p, 2).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0], TextValue::Text("a".into()));
        assert_eq!(row[1], TextValue::Null);
    }

    /// Only the `M` field of an error payload is returned.
    #[test]
    fn test_error_message_extracts_m_field() {
        let mut p = Vec::new();
        p.push(b'S');
        p.extend_from_slice(b"ERROR\0");
        p.push(b'C');
        p.extend_from_slice(b"28P01\0");
        p.push(b'M');
        p.extend_from_slice(b"password authentication failed\0");
        p.push(0);
        assert_eq!(error_message(&p), "password authentication failed");
    }

    #[test]
    fn test_parse_parameter_status() {
        let mut p = Vec::new();
        p.extend_from_slice(b"client_encoding\0");
        p.extend_from_slice(b"UTF8\0");
        let (k, v) = parse_parameter_status(&p).unwrap();
        assert_eq!(k, "client_encoding");
        assert_eq!(v, "UTF8");
    }

    /// The list ends at the empty name that follows the last mechanism.
    #[test]
    fn test_parse_sasl_mechanisms() {
        let mut p = Vec::new();
        p.extend_from_slice(b"SCRAM-SHA-256\0");
        p.extend_from_slice(b"SCRAM-SHA-256-PLUS\0");
        p.push(0);
        let m = parse_sasl_mechanisms(&p);
        assert_eq!(m.len(), 2);
        assert_eq!(m[0], "SCRAM-SHA-256");
        assert_eq!(m[1], "SCRAM-SHA-256-PLUS");
    }

    #[test]
    fn test_generate_nonce_is_url_safe() {
        let n = generate_nonce();
        assert!(n.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
        assert!(n.len() >= 18);
    }

    /// The count is the last token of the tag; tags without one give `None`.
    #[test]
    fn test_query_result_rows_affected() {
        let r = QueryResult {
            columns: Vec::new(),
            rows: Vec::new(),
            command_tag: "INSERT 0 5".into(),
        };
        assert_eq!(r.rows_affected(), Some(5));
        let r = QueryResult {
            columns: Vec::new(),
            rows: Vec::new(),
            command_tag: "SET".into(),
        };
        assert_eq!(r.rows_affected(), None);
    }
}