1use super::auth::{md5_password_response, Scram};
14use super::error::{BackendError, BackendResult};
15use super::stream::Stream;
16use super::tls::{negotiate, TlsMode};
17use super::types::{encode_literal, ParamValue, TextValue};
18use crate::protocol::{Message, MessageType, ProtocolCodec};
19use bytes::{Buf, BufMut, BytesMut};
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::TcpStream;
24
25#[derive(Debug, Clone)]
27pub struct BackendConfig {
28 pub host: String,
30 pub port: u16,
32 pub user: String,
34 pub password: Option<String>,
37 pub database: Option<String>,
39 pub application_name: Option<String>,
42 pub tls_mode: TlsMode,
44 pub connect_timeout: Duration,
46 pub query_timeout: Duration,
48 pub tls_config: Arc<rustls::ClientConfig>,
51}
52
53impl BackendConfig {
54 pub fn address(&self) -> String {
55 format!("{}:{}", self.host, self.port)
56 }
57}
58
59pub struct BackendClient {
61 stream: Stream,
62 pub server_parameters: std::collections::HashMap<String, String>,
65 pub backend_pid: Option<u32>,
67 pub backend_secret: Option<u32>,
68}
69
70impl BackendClient {
71 pub async fn connect(cfg: &BackendConfig) -> BackendResult<Self> {
75 tokio::time::timeout(cfg.connect_timeout, Self::connect_inner(cfg))
76 .await
77 .map_err(|_| {
78 BackendError::Io(std::io::Error::new(
79 std::io::ErrorKind::TimedOut,
80 format!(
81 "connect to {} exceeded {:?}",
82 cfg.address(),
83 cfg.connect_timeout
84 ),
85 ))
86 })?
87 }
88
89 async fn connect_inner(cfg: &BackendConfig) -> BackendResult<Self> {
90 let tcp = TcpStream::connect(cfg.address()).await?;
91 let mut stream = negotiate(tcp, cfg.tls_mode, cfg.tls_config.clone(), &cfg.host).await?;
92
93 let startup = build_startup(cfg);
95 stream.write_all(&startup).await?;
96
97 let mut server_parameters = std::collections::HashMap::new();
98 let mut backend_pid = None;
99 let mut backend_secret = None;
100 let mut buffer = BytesMut::with_capacity(4096);
101 let codec = ProtocolCodec::new();
102 let mut scram_state: Option<Scram> = None;
103
104 loop {
105 let msg = read_one(&mut stream, &mut buffer, &codec).await?;
106 match msg.msg_type {
107 MessageType::AuthRequest => {
108 handle_auth(&mut stream, &msg, cfg, &mut scram_state).await?;
109 }
110 MessageType::ParameterStatus => {
111 if let Some((k, v)) = parse_parameter_status(&msg.payload) {
112 server_parameters.insert(k, v);
113 }
114 }
115 MessageType::BackendKeyData => {
116 if msg.payload.len() >= 8 {
117 backend_pid =
118 Some(u32::from_be_bytes(msg.payload[0..4].try_into().unwrap()));
119 backend_secret =
120 Some(u32::from_be_bytes(msg.payload[4..8].try_into().unwrap()));
121 }
122 }
123 MessageType::ReadyForQuery => {
124 return Ok(Self {
125 stream,
126 server_parameters,
127 backend_pid,
128 backend_secret,
129 });
130 }
131 MessageType::ErrorResponse => {
132 return Err(BackendError::BackendError(error_message(&msg.payload)));
133 }
134 MessageType::NoticeResponse => {
135 }
137 other => {
138 return Err(BackendError::Protocol(format!(
139 "unexpected message during startup: {:?}",
140 other
141 )));
142 }
143 }
144 }
145 }
146
147 pub async fn simple_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
152 self.run_query(sql).await
153 }
154
155 pub async fn query_with_params(
159 &mut self,
160 sql: &str,
161 params: &[ParamValue],
162 ) -> BackendResult<QueryResult> {
163 let substituted = interpolate_params(sql, params)?;
164 self.run_query(&substituted).await
165 }
166
167 pub async fn query_scalar(&mut self, sql: &str) -> BackendResult<TextValue> {
169 let res = self.simple_query(sql).await?;
170 if res.rows.len() != 1 {
171 return Err(BackendError::Protocol(format!(
172 "expected 1 row, got {}",
173 res.rows.len()
174 )));
175 }
176 if res.columns.len() != 1 {
177 return Err(BackendError::Protocol(format!(
178 "expected 1 column, got {}",
179 res.columns.len()
180 )));
181 }
182 Ok(res
183 .rows
184 .into_iter()
185 .next()
186 .unwrap()
187 .into_iter()
188 .next()
189 .unwrap())
190 }
191
192 pub async fn execute(&mut self, sql: &str) -> BackendResult<String> {
195 let res = self.simple_query(sql).await?;
196 Ok(res.command_tag)
197 }
198
199 pub async fn copy_in(&mut self, copy_sql: &str, data: &[u8]) -> BackendResult<u64> {
208 let t = Duration::from_secs(600);
211 tokio::time::timeout(t, Self::copy_in_inner(&mut self.stream, copy_sql, data))
212 .await
213 .map_err(|_| {
214 BackendError::Io(std::io::Error::new(
215 std::io::ErrorKind::TimedOut,
216 format!("COPY exceeded {:?}", t),
217 ))
218 })?
219 }
220
221 async fn copy_in_inner(stream: &mut Stream, copy_sql: &str, data: &[u8]) -> BackendResult<u64> {
222 let mut payload = BytesMut::with_capacity(copy_sql.len() + 1);
224 payload.extend_from_slice(copy_sql.as_bytes());
225 payload.put_u8(0);
226 stream
227 .write_all(&Message::new(MessageType::Query, payload).encode())
228 .await?;
229
230 let mut buffer = BytesMut::with_capacity(8192);
231 let codec = ProtocolCodec::new();
232
233 loop {
238 let msg = read_one(stream, &mut buffer, &codec).await?;
239 match msg.msg_type {
240 MessageType::Unknown(b'G') => break,
241 MessageType::ErrorResponse => {
242 let e = error_message(&msg.payload);
243 Self::drain_to_ready(stream, &mut buffer, &codec).await?;
244 return Err(BackendError::BackendError(e));
245 }
246 MessageType::ReadyForQuery => {
247 return Err(BackendError::Protocol(
248 "COPY: ReadyForQuery before CopyInResponse".into(),
249 ));
250 }
251 _ => {} }
253 }
254
255 const CHUNK: usize = 64 * 1024;
257 let mut off = 0;
258 while off < data.len() {
259 let end = (off + CHUNK).min(data.len());
260 let mut p = BytesMut::with_capacity(end - off);
261 p.extend_from_slice(&data[off..end]);
262 stream
263 .write_all(&Message::new(MessageType::CopyData, p).encode())
264 .await?;
265 off = end;
266 }
267 stream
268 .write_all(&Message::new(MessageType::CopyDone, BytesMut::new()).encode())
269 .await?;
270
271 let mut tag = String::new();
273 let mut last_error = None;
274 loop {
275 let msg = read_one(stream, &mut buffer, &codec).await?;
276 match msg.msg_type {
277 MessageType::CommandComplete | MessageType::Close => {
278 tag = parse_cstring(&msg.payload);
279 }
280 MessageType::ErrorResponse => {
281 last_error = Some(error_message(&msg.payload));
282 }
283 MessageType::ReadyForQuery => {
284 if let Some(e) = last_error {
285 return Err(BackendError::BackendError(e));
286 }
287 let n = tag
289 .rsplit(' ')
290 .next()
291 .and_then(|s| s.parse::<u64>().ok())
292 .unwrap_or(0);
293 return Ok(n);
294 }
295 _ => {}
296 }
297 }
298 }
299
300 async fn drain_to_ready(
301 stream: &mut Stream,
302 buffer: &mut BytesMut,
303 codec: &ProtocolCodec,
304 ) -> BackendResult<()> {
305 loop {
306 if read_one(stream, buffer, codec).await?.msg_type == MessageType::ReadyForQuery {
307 return Ok(());
308 }
309 }
310 }
311
312 async fn run_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
313 let t = self.stream_query_timeout();
314 tokio::time::timeout(t, Self::run_query_inner(&mut self.stream, sql))
315 .await
316 .map_err(|_| {
317 BackendError::Io(std::io::Error::new(
318 std::io::ErrorKind::TimedOut,
319 format!("query exceeded {:?}: {}", t, truncate(sql, 64)),
320 ))
321 })?
322 }
323
324 fn stream_query_timeout(&self) -> Duration {
325 Duration::from_secs(30)
328 }
329
330 async fn run_query_inner(stream: &mut Stream, sql: &str) -> BackendResult<QueryResult> {
331 let mut payload = BytesMut::with_capacity(sql.len() + 1);
333 payload.extend_from_slice(sql.as_bytes());
334 payload.put_u8(0);
335 let frame = Message::new(MessageType::Query, payload).encode();
336 stream.write_all(&frame).await?;
337
338 let mut buffer = BytesMut::with_capacity(8192);
339 let codec = ProtocolCodec::new();
340 let mut columns: Vec<ColumnMeta> = Vec::new();
341 let mut rows: Vec<Vec<TextValue>> = Vec::new();
342 let mut command_tag = String::new();
343 let mut last_error: Option<String> = None;
344
345 loop {
346 let msg = read_one(stream, &mut buffer, &codec).await?;
347 match msg.msg_type {
348 MessageType::RowDescription => {
355 columns = parse_row_description(&msg.payload);
356 }
357 MessageType::DataRow => {
358 let row = parse_data_row(&msg.payload, columns.len())?;
359 rows.push(row);
360 }
361 MessageType::CommandComplete | MessageType::Close => {
366 command_tag = parse_cstring(&msg.payload);
367 }
368 MessageType::EmptyQueryResponse => {
369 command_tag = String::new();
370 }
371 MessageType::ErrorResponse => {
372 last_error = Some(error_message(&msg.payload));
373 }
374 MessageType::NoticeResponse => {
375 tracing::debug!(notice = %error_message(&msg.payload), "backend notice");
376 }
377 MessageType::ReadyForQuery => {
378 if let Some(e) = last_error {
379 return Err(BackendError::BackendError(e));
380 }
381 return Ok(QueryResult {
382 columns,
383 rows,
384 command_tag,
385 });
386 }
387 MessageType::ParameterStatus => {
388 }
392 _other => {
393 }
397 }
398 }
399 }
400
401 pub async fn close(mut self) {
403 let term = Message::new(MessageType::Terminate, BytesMut::new()).encode();
404 let _ = self.stream.write_all(&term).await;
405 let _ = self.stream.shutdown().await;
406 }
407
408 pub fn is_tls(&self) -> bool {
410 self.stream.is_tls()
411 }
412}
413
/// Name and type of one result column, taken from a `RowDescription` field.
#[derive(Debug, Clone)]
pub struct ColumnMeta {
    /// Column name (decoded lossily as UTF-8).
    pub name: String,
    /// PostgreSQL type OID of the column.
    pub type_oid: u32,
}
423
/// Collected result of one simple query.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Columns from the last `RowDescription` received; empty if the server
    /// sent none.
    pub columns: Vec<ColumnMeta>,
    /// Every `DataRow`, in arrival order; SQL NULL is `TextValue::Null`.
    pub rows: Vec<Vec<TextValue>>,
    /// Tag from the last `CommandComplete` (e.g. `"INSERT 0 5"`); empty for
    /// an empty query string.
    pub command_tag: String,
}
430
431impl QueryResult {
432 pub fn rows_affected(&self) -> Option<u64> {
435 self.command_tag
436 .split_whitespace()
437 .last()
438 .and_then(|s| s.parse().ok())
439 }
440}
441
442fn build_startup(cfg: &BackendConfig) -> Vec<u8> {
447 let mut payload = BytesMut::with_capacity(128);
448 payload.put_u32(196608);
450 put_cstring(&mut payload, "user");
451 put_cstring(&mut payload, &cfg.user);
452 if let Some(db) = &cfg.database {
453 put_cstring(&mut payload, "database");
454 put_cstring(&mut payload, db);
455 }
456 put_cstring(&mut payload, "application_name");
457 put_cstring(
458 &mut payload,
459 cfg.application_name.as_deref().unwrap_or("heliosdb-proxy"),
460 );
461 put_cstring(&mut payload, "client_encoding");
462 put_cstring(&mut payload, "UTF8");
463 payload.put_u8(0); let mut framed = BytesMut::with_capacity(payload.len() + 4);
466 framed.put_u32((payload.len() + 4) as u32);
467 framed.extend_from_slice(&payload);
468 framed.to_vec()
469}
470
471fn put_cstring(buf: &mut BytesMut, s: &str) {
472 buf.extend_from_slice(s.as_bytes());
473 buf.put_u8(0);
474}
475
/// Decodes the leading NUL-terminated string in `payload`.
///
/// Reading stops at the first NUL, or at the end of `payload` if there is
/// none. Invalid UTF-8 is replaced with U+FFFD.
fn parse_cstring(payload: &[u8]) -> String {
    let text = payload.split(|&b| b == 0).next().unwrap_or_default();
    String::from_utf8_lossy(text).into_owned()
}
483
/// Splits a `ParameterStatus` payload into its name and value.
///
/// Returns `None` when the name has no NUL terminator. A missing terminator
/// on the value is tolerated, in which case the value runs to the end of the
/// payload.
fn parse_parameter_status(payload: &[u8]) -> Option<(String, String)> {
    let name_len = payload.iter().position(|&b| b == 0)?;
    let (name, tail) = payload.split_at(name_len);
    // `tail[0]` is the name's NUL terminator.
    let value = tail[1..].split(|&b| b == 0).next().unwrap_or_default();
    Some((
        String::from_utf8_lossy(name).into_owned(),
        String::from_utf8_lossy(value).into_owned(),
    ))
}
492
493fn parse_row_description(payload: &[u8]) -> Vec<ColumnMeta> {
494 let mut p = BytesMut::from(payload);
495 if p.remaining() < 2 {
496 return Vec::new();
497 }
498 let n = p.get_u16() as usize;
499 let mut cols = Vec::with_capacity(n);
500 for _ in 0..n {
501 let end = match p.as_ref().iter().position(|&b| b == 0) {
503 Some(i) => i,
504 None => break,
505 };
506 let name = String::from_utf8_lossy(&p.as_ref()[..end]).into_owned();
507 p.advance(end + 1);
508 if p.remaining() < 18 {
509 break;
510 }
511 let _table_oid = p.get_u32();
512 let _column_number = p.get_u16();
513 let type_oid = p.get_u32();
514 let _type_len = p.get_i16();
515 let _type_mod = p.get_i32();
516 let _format_code = p.get_u16();
517 cols.push(ColumnMeta { name, type_oid });
518 }
519 cols
520}
521
522fn parse_data_row(payload: &[u8], column_count: usize) -> BackendResult<Vec<TextValue>> {
523 let mut p = BytesMut::from(payload);
524 if p.remaining() < 2 {
525 return Err(BackendError::Protocol("truncated DataRow".into()));
526 }
527 let n = p.get_u16() as usize;
528 let mut out = Vec::with_capacity(n);
529 for _ in 0..n {
530 if p.remaining() < 4 {
531 return Err(BackendError::Protocol("truncated DataRow field".into()));
532 }
533 let len = p.get_i32();
534 if len == -1 {
535 out.push(TextValue::Null);
536 } else {
537 let len = len as usize;
538 if p.remaining() < len {
539 return Err(BackendError::Protocol("truncated DataRow value".into()));
540 }
541 let bytes = p.split_to(len);
542 out.push(TextValue::Text(
543 String::from_utf8_lossy(&bytes).into_owned(),
544 ));
545 }
546 }
547 let _ = column_count;
548 Ok(out)
549}
550
/// Extracts the human-readable message (`'M'` field) from an ErrorResponse or
/// NoticeResponse payload.
///
/// The payload is a sequence of (code byte, NUL-terminated value) fields
/// ended by a zero code byte. If several `'M'` fields appear, the last one
/// wins. Returns `"<no message>"` when there is none.
fn error_message(payload: &[u8]) -> String {
    let mut message: Option<String> = None;
    let mut rest = payload;
    while let Some((&code, after_code)) = rest.split_first() {
        if code == 0 {
            break;
        }
        let (value, next) = match after_code.iter().position(|&b| b == 0) {
            Some(nul) => (&after_code[..nul], &after_code[nul + 1..]),
            None => (after_code, &after_code[after_code.len()..]),
        };
        if code == b'M' {
            message = Some(String::from_utf8_lossy(value).into_owned());
        }
        rest = next;
    }
    message.unwrap_or_else(|| "<no message>".to_string())
}
574
575async fn read_one(
576 stream: &mut Stream,
577 buffer: &mut BytesMut,
578 codec: &ProtocolCodec,
579) -> BackendResult<Message> {
580 loop {
581 if let Some(mut msg) = codec
582 .decode_message(buffer)
583 .map_err(|e| BackendError::Protocol(e.to_string()))?
584 {
585 msg.msg_type = match msg.msg_type {
592 MessageType::Sync => MessageType::ParameterStatus,
593 MessageType::Describe => MessageType::DataRow,
594 MessageType::Execute => MessageType::ErrorResponse,
595 MessageType::Close => MessageType::CommandComplete,
596 other => other,
597 };
598 return Ok(msg);
599 }
600 let mut tmp = vec![0u8; 4096];
601 let n = stream.read(&mut tmp).await?;
602 if n == 0 {
603 return Err(BackendError::Closed);
604 }
605 buffer.extend_from_slice(&tmp[..n]);
606 }
607}
608
609async fn handle_auth(
610 stream: &mut Stream,
611 msg: &Message,
612 cfg: &BackendConfig,
613 scram_state: &mut Option<Scram>,
614) -> BackendResult<()> {
615 if msg.payload.len() < 4 {
616 return Err(BackendError::Protocol(
617 "AuthRequest payload < 4 bytes".into(),
618 ));
619 }
620 let code = u32::from_be_bytes([
621 msg.payload[0],
622 msg.payload[1],
623 msg.payload[2],
624 msg.payload[3],
625 ]);
626 match code {
627 0 => Ok(()), 5 => {
629 if msg.payload.len() < 8 {
631 return Err(BackendError::Protocol("AuthenticationMD5 truncated".into()));
632 }
633 let salt: [u8; 4] = [
634 msg.payload[4],
635 msg.payload[5],
636 msg.payload[6],
637 msg.payload[7],
638 ];
639 let password = cfg.password.as_deref().ok_or_else(|| {
640 BackendError::Auth("server requested MD5 but no password configured".into())
641 })?;
642 let payload = md5_password_response(&cfg.user, password, &salt);
643 write_password_message(stream, &payload).await
644 }
645 3 => {
646 let password = cfg.password.as_deref().ok_or_else(|| {
648 BackendError::Auth("server requested password but none configured".into())
649 })?;
650 let mut payload = Vec::with_capacity(password.len() + 1);
651 payload.extend_from_slice(password.as_bytes());
652 payload.push(0);
653 write_password_message(stream, &payload).await
654 }
655 10 => {
656 let mechs = parse_sasl_mechanisms(&msg.payload[4..]);
658 if !mechs.iter().any(|m| m == "SCRAM-SHA-256") {
659 return Err(BackendError::Auth(format!(
660 "no supported SASL mechanism; server offered {:?}",
661 mechs
662 )));
663 }
664 let nonce = generate_nonce();
665 let (scram, first) = Scram::client_first(nonce);
666 *scram_state = Some(scram);
667 write_password_message(stream, &first.0).await
668 }
669 11 => {
670 let scram = scram_state
672 .as_mut()
673 .ok_or_else(|| BackendError::Auth("SASLContinue before SASL start".into()))?;
674 let password = cfg
675 .password
676 .as_deref()
677 .ok_or_else(|| BackendError::Auth("SCRAM requires a password".into()))?;
678 let out = scram.client_final(&msg.payload[4..], password)?;
679 write_password_message(stream, &out.0).await
680 }
681 12 => {
682 let scram = scram_state
684 .as_ref()
685 .ok_or_else(|| BackendError::Auth("SASLFinal before SASL start".into()))?;
686 scram.verify_server(&msg.payload[4..])
687 }
688 other => Err(BackendError::Auth(format!(
689 "unsupported authentication request code: {}",
690 other
691 ))),
692 }
693}
694
695async fn write_password_message(stream: &mut Stream, payload: &[u8]) -> BackendResult<()> {
696 let mut buf = BytesMut::with_capacity(payload.len() + 5);
697 buf.put_u8(b'p');
698 buf.put_u32((payload.len() + 4) as u32);
699 buf.extend_from_slice(payload);
700 stream.write_all(&buf).await?;
701 Ok(())
702}
703
/// Lists the SASL mechanism names in an `AuthenticationSASL` body.
///
/// The body is a sequence of NUL-terminated names ended by an empty name.
/// Parsing stops at the first empty name; a final unterminated name is still
/// returned.
fn parse_sasl_mechanisms(payload: &[u8]) -> Vec<String> {
    payload
        .split(|&b| b == 0)
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
720
721fn generate_nonce() -> String {
722 use base64::Engine as _;
723 use rand::RngCore;
724 let mut bytes = [0u8; 18];
725 rand::thread_rng().fill_bytes(&mut bytes);
726 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
728}
729
730fn interpolate_params(sql: &str, params: &[ParamValue]) -> BackendResult<String> {
731 let mut out = String::with_capacity(sql.len());
736 let bytes = sql.as_bytes();
737 let mut i = 0;
738 let mut in_string = false;
739 let mut quote_char = 0u8;
740 while i < bytes.len() {
741 let b = bytes[i];
742 if in_string {
743 out.push(b as char);
744 if b == quote_char {
745 if i + 1 < bytes.len() && bytes[i + 1] == quote_char {
747 out.push(quote_char as char);
748 i += 2;
749 continue;
750 }
751 in_string = false;
752 }
753 i += 1;
754 continue;
755 }
756 if b == b'\'' || b == b'"' {
757 in_string = true;
758 quote_char = b;
759 out.push(b as char);
760 i += 1;
761 continue;
762 }
763 if b == b'$' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
764 let mut j = i + 1;
765 while j < bytes.len() && bytes[j].is_ascii_digit() {
766 j += 1;
767 }
768 let idx: usize = std::str::from_utf8(&bytes[i + 1..j])
769 .unwrap()
770 .parse()
771 .map_err(|_| {
772 BackendError::Protocol(format!("invalid parameter reference at byte {}", i))
773 })?;
774 if idx == 0 || idx > params.len() {
775 return Err(BackendError::Protocol(format!(
776 "parameter ${} out of range (have {})",
777 idx,
778 params.len()
779 )));
780 }
781 out.push_str(&encode_literal(¶ms[idx - 1]));
782 i = j;
783 continue;
784 }
785 out.push(b as char);
786 i += 1;
787 }
788 Ok(out)
789}
790
/// Returns the prefix of `s` holding at most `n` characters (not bytes), so
/// the cut never splits a multi-byte character.
fn truncate(s: &str, n: usize) -> &str {
    let cut = s
        .char_indices()
        .nth(n)
        .map_or(s.len(), |(byte_idx, _)| byte_idx);
    &s[..cut]
}
797
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::types::ParamValue;

    /// True if `needle` occurs anywhere in `haystack`.
    fn contains(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|w| w == needle)
    }

    #[test]
    fn test_build_startup_has_user_and_protocol_version() {
        let cfg = BackendConfig {
            host: "localhost".into(),
            port: 5432,
            user: "alice".into(),
            password: None,
            database: Some("app".into()),
            application_name: None,
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_secs(5),
            query_timeout: Duration::from_secs(5),
            tls_config: crate::backend::tls::default_client_config(),
        };
        let startup = build_startup(&cfg);
        // Bytes 0..4 hold the length prefix; the protocol version comes next.
        let version: [u8; 4] = startup[4..8].try_into().unwrap();
        assert_eq!(u32::from_be_bytes(version), 196608);
        assert!(contains(&startup, b"user\0"));
        assert!(contains(&startup, b"database\0a"));
    }

    #[test]
    fn test_interpolate_params_basic() {
        let params = [ParamValue::Int(42), ParamValue::Text("alice".into())];
        let rendered =
            interpolate_params("SELECT * FROM t WHERE id = $1 AND name = $2", &params).unwrap();
        assert_eq!(rendered, "SELECT * FROM t WHERE id = 42 AND name = 'alice'");
    }

    #[test]
    fn test_interpolate_params_escapes_quotes() {
        let params = [ParamValue::Text("o'brien".into())];
        let rendered = interpolate_params("SELECT * FROM t WHERE name = $1", &params).unwrap();
        assert_eq!(rendered, "SELECT * FROM t WHERE name = 'o''brien'");
    }

    #[test]
    fn test_interpolate_params_leaves_dollar_in_string_alone() {
        let params = [ParamValue::Int(1)];
        let rendered = interpolate_params("SELECT '$1' AS lit, $1 AS val", &params).unwrap();
        assert_eq!(rendered, "SELECT '$1' AS lit, 1 AS val");
    }

    #[test]
    fn test_interpolate_params_out_of_range() {
        let params = [ParamValue::Int(1)];
        let result = interpolate_params("SELECT $2", &params);
        assert!(matches!(result, Err(BackendError::Protocol(_))));
    }

    #[test]
    fn test_parse_row_description_shape() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&1u16.to_be_bytes()); // field count
        payload.extend_from_slice(b"x\0"); // column name
        payload.extend_from_slice(&0u32.to_be_bytes()); // table OID
        payload.extend_from_slice(&0u16.to_be_bytes()); // column number
        payload.extend_from_slice(&23u32.to_be_bytes()); // type OID
        payload.extend_from_slice(&4i16.to_be_bytes()); // type length
        payload.extend_from_slice(&(-1i32).to_be_bytes()); // type modifier
        payload.extend_from_slice(&0u16.to_be_bytes()); // format code
        let cols = parse_row_description(&payload);
        assert_eq!(cols.len(), 1);
        assert_eq!(cols[0].name, "x");
        assert_eq!(cols[0].type_oid, 23);
    }

    #[test]
    fn test_parse_data_row_with_null() {
        let mut payload = Vec::new();
        payload.extend_from_slice(&2u16.to_be_bytes()); // field count
        payload.extend_from_slice(&1i32.to_be_bytes()); // first field length
        payload.extend_from_slice(b"a");
        payload.extend_from_slice(&(-1i32).to_be_bytes()); // NULL
        let row = parse_data_row(&payload, 2).unwrap();
        assert_eq!(row, vec![TextValue::Text("a".into()), TextValue::Null]);
    }

    #[test]
    fn test_error_message_extracts_m_field() {
        let payload = b"SERROR\0C28P01\0Mpassword authentication failed\0\0";
        assert_eq!(error_message(payload), "password authentication failed");
    }

    #[test]
    fn test_parse_parameter_status() {
        let (k, v) = parse_parameter_status(b"client_encoding\0UTF8\0").unwrap();
        assert_eq!(k, "client_encoding");
        assert_eq!(v, "UTF8");
    }

    #[test]
    fn test_parse_sasl_mechanisms() {
        let mechanisms = parse_sasl_mechanisms(b"SCRAM-SHA-256\0SCRAM-SHA-256-PLUS\0\0");
        assert_eq!(mechanisms, ["SCRAM-SHA-256", "SCRAM-SHA-256-PLUS"]);
    }

    #[test]
    fn test_generate_nonce_is_url_safe() {
        let nonce = generate_nonce();
        assert!(nonce.len() >= 18);
        assert!(nonce
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_'));
    }

    #[test]
    fn test_query_result_rows_affected() {
        let with_tag = |tag: &str| QueryResult {
            columns: Vec::new(),
            rows: Vec::new(),
            command_tag: tag.to_string(),
        };
        assert_eq!(with_tag("INSERT 0 5").rows_affected(), Some(5));
        assert_eq!(with_tag("SET").rows_affected(), None);
    }
}