1use super::auth::{md5_password_response, Scram};
14use super::error::{BackendError, BackendResult};
15use super::stream::Stream;
16use super::tls::{negotiate, TlsMode};
17use super::types::{encode_literal, ParamValue, TextValue};
18use crate::protocol::{Message, MessageType, ProtocolCodec};
19use bytes::{Buf, BufMut, BytesMut};
20use std::sync::Arc;
21use std::time::Duration;
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::TcpStream;
24
25#[derive(Debug, Clone)]
27pub struct BackendConfig {
28 pub host: String,
30 pub port: u16,
32 pub user: String,
34 pub password: Option<String>,
37 pub database: Option<String>,
39 pub application_name: Option<String>,
42 pub tls_mode: TlsMode,
44 pub connect_timeout: Duration,
46 pub query_timeout: Duration,
48 pub tls_config: Arc<rustls::ClientConfig>,
51}
52
53impl BackendConfig {
54 pub fn address(&self) -> String {
55 format!("{}:{}", self.host, self.port)
56 }
57}
58
59pub struct BackendClient {
61 stream: Stream,
62 pub server_parameters: std::collections::HashMap<String, String>,
65 pub backend_pid: Option<u32>,
67 pub backend_secret: Option<u32>,
68}
69
70impl BackendClient {
71 pub async fn connect(cfg: &BackendConfig) -> BackendResult<Self> {
75 tokio::time::timeout(cfg.connect_timeout, Self::connect_inner(cfg))
76 .await
77 .map_err(|_| BackendError::Io(std::io::Error::new(
78 std::io::ErrorKind::TimedOut,
79 format!("connect to {} exceeded {:?}", cfg.address(), cfg.connect_timeout),
80 )))?
81 }
82
83 async fn connect_inner(cfg: &BackendConfig) -> BackendResult<Self> {
84 let tcp = TcpStream::connect(cfg.address()).await?;
85 let mut stream =
86 negotiate(tcp, cfg.tls_mode, cfg.tls_config.clone(), &cfg.host).await?;
87
88 let startup = build_startup(cfg);
90 stream.write_all(&startup).await?;
91
92 let mut server_parameters = std::collections::HashMap::new();
93 let mut backend_pid = None;
94 let mut backend_secret = None;
95 let mut buffer = BytesMut::with_capacity(4096);
96 let codec = ProtocolCodec::new();
97 let mut scram_state: Option<Scram> = None;
98
99 loop {
100 let msg = read_one(&mut stream, &mut buffer, &codec).await?;
101 match msg.msg_type {
102 MessageType::AuthRequest => {
103 handle_auth(
104 &mut stream,
105 &msg,
106 cfg,
107 &mut scram_state,
108 )
109 .await?;
110 }
111 MessageType::ParameterStatus => {
112 if let Some((k, v)) = parse_parameter_status(&msg.payload) {
113 server_parameters.insert(k, v);
114 }
115 }
116 MessageType::BackendKeyData => {
117 if msg.payload.len() >= 8 {
118 backend_pid = Some(u32::from_be_bytes(
119 msg.payload[0..4].try_into().unwrap(),
120 ));
121 backend_secret = Some(u32::from_be_bytes(
122 msg.payload[4..8].try_into().unwrap(),
123 ));
124 }
125 }
126 MessageType::ReadyForQuery => {
127 return Ok(Self {
128 stream,
129 server_parameters,
130 backend_pid,
131 backend_secret,
132 });
133 }
134 MessageType::ErrorResponse => {
135 return Err(BackendError::BackendError(error_message(&msg.payload)));
136 }
137 MessageType::NoticeResponse => {
138 }
140 other => {
141 return Err(BackendError::Protocol(format!(
142 "unexpected message during startup: {:?}",
143 other
144 )));
145 }
146 }
147 }
148 }
149
150 pub async fn simple_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
155 self.run_query(sql).await
156 }
157
158 pub async fn query_with_params(
162 &mut self,
163 sql: &str,
164 params: &[ParamValue],
165 ) -> BackendResult<QueryResult> {
166 let substituted = interpolate_params(sql, params)?;
167 self.run_query(&substituted).await
168 }
169
170 pub async fn query_scalar(&mut self, sql: &str) -> BackendResult<TextValue> {
172 let res = self.simple_query(sql).await?;
173 if res.rows.len() != 1 {
174 return Err(BackendError::Protocol(format!(
175 "expected 1 row, got {}",
176 res.rows.len()
177 )));
178 }
179 if res.columns.len() != 1 {
180 return Err(BackendError::Protocol(format!(
181 "expected 1 column, got {}",
182 res.columns.len()
183 )));
184 }
185 Ok(res.rows.into_iter().next().unwrap().into_iter().next().unwrap())
186 }
187
188 pub async fn execute(&mut self, sql: &str) -> BackendResult<String> {
191 let res = self.simple_query(sql).await?;
192 Ok(res.command_tag)
193 }
194
195 async fn run_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
196 let t = self.stream_query_timeout();
197 tokio::time::timeout(t, Self::run_query_inner(&mut self.stream, sql))
198 .await
199 .map_err(|_| BackendError::Io(std::io::Error::new(
200 std::io::ErrorKind::TimedOut,
201 format!("query exceeded {:?}: {}", t, truncate(sql, 64)),
202 )))?
203 }
204
205 fn stream_query_timeout(&self) -> Duration {
206 Duration::from_secs(30)
209 }
210
211 async fn run_query_inner(stream: &mut Stream, sql: &str) -> BackendResult<QueryResult> {
212 let mut payload = BytesMut::with_capacity(sql.len() + 1);
214 payload.extend_from_slice(sql.as_bytes());
215 payload.put_u8(0);
216 let frame = Message::new(MessageType::Query, payload).encode();
217 stream.write_all(&frame).await?;
218
219 let mut buffer = BytesMut::with_capacity(8192);
220 let codec = ProtocolCodec::new();
221 let mut columns: Vec<ColumnMeta> = Vec::new();
222 let mut rows: Vec<Vec<TextValue>> = Vec::new();
223 let mut command_tag = String::new();
224 let mut last_error: Option<String> = None;
225
226 loop {
227 let msg = read_one(stream, &mut buffer, &codec).await?;
228 match msg.msg_type {
229 MessageType::RowDescription => {
236 columns = parse_row_description(&msg.payload);
237 }
238 MessageType::DataRow => {
239 let row = parse_data_row(&msg.payload, columns.len())?;
240 rows.push(row);
241 }
242 MessageType::CommandComplete | MessageType::Close => {
247 command_tag = parse_cstring(&msg.payload);
248 }
249 MessageType::EmptyQueryResponse => {
250 command_tag = String::new();
251 }
252 MessageType::ErrorResponse => {
253 last_error = Some(error_message(&msg.payload));
254 }
255 MessageType::NoticeResponse => {
256 tracing::debug!(notice = %error_message(&msg.payload), "backend notice");
257 }
258 MessageType::ReadyForQuery => {
259 if let Some(e) = last_error {
260 return Err(BackendError::BackendError(e));
261 }
262 return Ok(QueryResult {
263 columns,
264 rows,
265 command_tag,
266 });
267 }
268 MessageType::ParameterStatus => {
269 }
273 _other => {
274 }
278 }
279 }
280 }
281
282 pub async fn close(mut self) {
284 let term = Message::new(MessageType::Terminate, BytesMut::new()).encode();
285 let _ = self.stream.write_all(&term).await;
286 let _ = self.stream.shutdown().await;
287 }
288
289 pub fn is_tls(&self) -> bool {
291 self.stream.is_tls()
292 }
293}
294
/// Name and type of one result column, as announced by a `RowDescription` message.
#[derive(Clone, Debug)]
pub struct ColumnMeta {
    /// Column label sent by the server (decoded lossily as UTF-8).
    pub name: String,
    /// OID of the column's data type.
    pub type_oid: u32,
}
304
305#[derive(Debug, Clone)]
306pub struct QueryResult {
307 pub columns: Vec<ColumnMeta>,
308 pub rows: Vec<Vec<TextValue>>,
309 pub command_tag: String,
310}
311
312impl QueryResult {
313 pub fn rows_affected(&self) -> Option<u64> {
316 self.command_tag
317 .split_whitespace()
318 .last()
319 .and_then(|s| s.parse().ok())
320 }
321}
322
323fn build_startup(cfg: &BackendConfig) -> Vec<u8> {
328 let mut payload = BytesMut::with_capacity(128);
329 payload.put_u32(196608);
331 put_cstring(&mut payload, "user");
332 put_cstring(&mut payload, &cfg.user);
333 if let Some(db) = &cfg.database {
334 put_cstring(&mut payload, "database");
335 put_cstring(&mut payload, db);
336 }
337 put_cstring(&mut payload, "application_name");
338 put_cstring(
339 &mut payload,
340 cfg.application_name
341 .as_deref()
342 .unwrap_or("heliosdb-proxy"),
343 );
344 put_cstring(&mut payload, "client_encoding");
345 put_cstring(&mut payload, "UTF8");
346 payload.put_u8(0); let mut framed = BytesMut::with_capacity(payload.len() + 4);
349 framed.put_u32((payload.len() + 4) as u32);
350 framed.extend_from_slice(&payload);
351 framed.to_vec()
352}
353
354fn put_cstring(buf: &mut BytesMut, s: &str) {
355 buf.extend_from_slice(s.as_bytes());
356 buf.put_u8(0);
357}
358
/// Decodes the bytes before the first NUL (or the whole slice when there is none)
/// as lossy UTF-8.
fn parse_cstring(payload: &[u8]) -> String {
    let until_nul = payload.split(|&byte| byte == 0).next().unwrap_or(payload);
    String::from_utf8_lossy(until_nul).into_owned()
}
363
/// Splits a `ParameterStatus` payload into `(name, value)`.
///
/// Returns `None` when the name has no NUL terminator; the value runs up to the next
/// NUL or to the end of the payload. Both are decoded as lossy UTF-8.
fn parse_parameter_status(payload: &[u8]) -> Option<(String, String)> {
    let mut halves = payload.splitn(2, |&byte| byte == 0);
    let name = halves.next()?;
    let remainder = halves.next()?;
    let value = remainder.split(|&byte| byte == 0).next().unwrap_or(remainder);
    Some((
        String::from_utf8_lossy(name).into_owned(),
        String::from_utf8_lossy(value).into_owned(),
    ))
}
372
373fn parse_row_description(payload: &[u8]) -> Vec<ColumnMeta> {
374 let mut p = BytesMut::from(payload);
375 if p.remaining() < 2 {
376 return Vec::new();
377 }
378 let n = p.get_u16() as usize;
379 let mut cols = Vec::with_capacity(n);
380 for _ in 0..n {
381 let end = match p.as_ref().iter().position(|&b| b == 0) {
383 Some(i) => i,
384 None => break,
385 };
386 let name = String::from_utf8_lossy(&p.as_ref()[..end]).into_owned();
387 p.advance(end + 1);
388 if p.remaining() < 18 {
389 break;
390 }
391 let _table_oid = p.get_u32();
392 let _column_number = p.get_u16();
393 let type_oid = p.get_u32();
394 let _type_len = p.get_i16();
395 let _type_mod = p.get_i32();
396 let _format_code = p.get_u16();
397 cols.push(ColumnMeta { name, type_oid });
398 }
399 cols
400}
401
402fn parse_data_row(payload: &[u8], column_count: usize) -> BackendResult<Vec<TextValue>> {
403 let mut p = BytesMut::from(payload);
404 if p.remaining() < 2 {
405 return Err(BackendError::Protocol("truncated DataRow".into()));
406 }
407 let n = p.get_u16() as usize;
408 let mut out = Vec::with_capacity(n);
409 for _ in 0..n {
410 if p.remaining() < 4 {
411 return Err(BackendError::Protocol("truncated DataRow field".into()));
412 }
413 let len = p.get_i32();
414 if len == -1 {
415 out.push(TextValue::Null);
416 } else {
417 let len = len as usize;
418 if p.remaining() < len {
419 return Err(BackendError::Protocol(
420 "truncated DataRow value".into(),
421 ));
422 }
423 let bytes = p.split_to(len);
424 out.push(TextValue::Text(
425 String::from_utf8_lossy(&bytes).into_owned(),
426 ));
427 }
428 }
429 let _ = column_count;
430 Ok(out)
431}
432
/// Extracts the human-readable message (`M` field) from an `ErrorResponse` or
/// `NoticeResponse` payload.
///
/// The payload is a sequence of `code byte + NUL-terminated string` fields ended by a
/// zero byte. The last `M` field wins. `"<no message>"` is returned when there is none.
fn error_message(payload: &[u8]) -> String {
    let mut message: Option<String> = None;
    for field in payload.split(|&byte| byte == 0) {
        match field.split_first() {
            // An empty segment means a zero code byte: the list terminator.
            None => break,
            Some((&b'M', text)) => message = Some(String::from_utf8_lossy(text).into_owned()),
            Some(_) => {}
        }
    }
    message.unwrap_or_else(|| "<no message>".to_string())
}
456
457async fn read_one(
458 stream: &mut Stream,
459 buffer: &mut BytesMut,
460 codec: &ProtocolCodec,
461) -> BackendResult<Message> {
462 loop {
463 if let Some(msg) = codec
464 .decode_message(buffer)
465 .map_err(|e| BackendError::Protocol(e.to_string()))?
466 {
467 return Ok(msg);
468 }
469 let mut tmp = vec![0u8; 4096];
470 let n = stream.read(&mut tmp).await?;
471 if n == 0 {
472 return Err(BackendError::Closed);
473 }
474 buffer.extend_from_slice(&tmp[..n]);
475 }
476}
477
478async fn handle_auth(
479 stream: &mut Stream,
480 msg: &Message,
481 cfg: &BackendConfig,
482 scram_state: &mut Option<Scram>,
483) -> BackendResult<()> {
484 if msg.payload.len() < 4 {
485 return Err(BackendError::Protocol(
486 "AuthRequest payload < 4 bytes".into(),
487 ));
488 }
489 let code =
490 u32::from_be_bytes([msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3]]);
491 match code {
492 0 => Ok(()), 5 => {
494 if msg.payload.len() < 8 {
496 return Err(BackendError::Protocol("AuthenticationMD5 truncated".into()));
497 }
498 let salt: [u8; 4] = [
499 msg.payload[4],
500 msg.payload[5],
501 msg.payload[6],
502 msg.payload[7],
503 ];
504 let password = cfg.password.as_deref().ok_or_else(|| {
505 BackendError::Auth("server requested MD5 but no password configured".into())
506 })?;
507 let payload = md5_password_response(&cfg.user, password, &salt);
508 write_password_message(stream, &payload).await
509 }
510 3 => {
511 let password = cfg.password.as_deref().ok_or_else(|| {
513 BackendError::Auth("server requested password but none configured".into())
514 })?;
515 let mut payload = Vec::with_capacity(password.len() + 1);
516 payload.extend_from_slice(password.as_bytes());
517 payload.push(0);
518 write_password_message(stream, &payload).await
519 }
520 10 => {
521 let mechs = parse_sasl_mechanisms(&msg.payload[4..]);
523 if !mechs.iter().any(|m| m == "SCRAM-SHA-256") {
524 return Err(BackendError::Auth(format!(
525 "no supported SASL mechanism; server offered {:?}",
526 mechs
527 )));
528 }
529 let nonce = generate_nonce();
530 let (scram, first) = Scram::client_first(nonce);
531 *scram_state = Some(scram);
532 write_password_message(stream, &first.0).await
533 }
534 11 => {
535 let scram = scram_state.as_mut().ok_or_else(|| {
537 BackendError::Auth("SASLContinue before SASL start".into())
538 })?;
539 let password = cfg.password.as_deref().ok_or_else(|| {
540 BackendError::Auth("SCRAM requires a password".into())
541 })?;
542 let out = scram.client_final(&msg.payload[4..], password)?;
543 write_password_message(stream, &out.0).await
544 }
545 12 => {
546 let scram = scram_state.as_ref().ok_or_else(|| {
548 BackendError::Auth("SASLFinal before SASL start".into())
549 })?;
550 scram.verify_server(&msg.payload[4..])
551 }
552 other => Err(BackendError::Auth(format!(
553 "unsupported authentication request code: {}",
554 other
555 ))),
556 }
557}
558
559async fn write_password_message(
560 stream: &mut Stream,
561 payload: &[u8],
562) -> BackendResult<()> {
563 let mut buf = BytesMut::with_capacity(payload.len() + 5);
564 buf.put_u8(b'p');
565 buf.put_u32((payload.len() + 4) as u32);
566 buf.extend_from_slice(payload);
567 stream.write_all(&buf).await?;
568 Ok(())
569}
570
/// Lists the SASL mechanism names in an `AuthenticationSASL` body.
///
/// Names are NUL-terminated and the list ends at the first empty name (or at the end
/// of the payload). Each name is decoded as lossy UTF-8.
fn parse_sasl_mechanisms(payload: &[u8]) -> Vec<String> {
    payload
        .split(|&byte| byte == 0)
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
587
588fn generate_nonce() -> String {
589 use base64::Engine as _;
590 use rand::RngCore;
591 let mut bytes = [0u8; 18];
592 rand::thread_rng().fill_bytes(&mut bytes);
593 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
595}
596
597fn interpolate_params(sql: &str, params: &[ParamValue]) -> BackendResult<String> {
598 let mut out = String::with_capacity(sql.len());
603 let bytes = sql.as_bytes();
604 let mut i = 0;
605 let mut in_string = false;
606 let mut quote_char = 0u8;
607 while i < bytes.len() {
608 let b = bytes[i];
609 if in_string {
610 out.push(b as char);
611 if b == quote_char {
612 if i + 1 < bytes.len() && bytes[i + 1] == quote_char {
614 out.push(quote_char as char);
615 i += 2;
616 continue;
617 }
618 in_string = false;
619 }
620 i += 1;
621 continue;
622 }
623 if b == b'\'' || b == b'"' {
624 in_string = true;
625 quote_char = b;
626 out.push(b as char);
627 i += 1;
628 continue;
629 }
630 if b == b'$' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
631 let mut j = i + 1;
632 while j < bytes.len() && bytes[j].is_ascii_digit() {
633 j += 1;
634 }
635 let idx: usize = std::str::from_utf8(&bytes[i + 1..j])
636 .unwrap()
637 .parse()
638 .map_err(|_| {
639 BackendError::Protocol(format!(
640 "invalid parameter reference at byte {}",
641 i
642 ))
643 })?;
644 if idx == 0 || idx > params.len() {
645 return Err(BackendError::Protocol(format!(
646 "parameter ${} out of range (have {})",
647 idx,
648 params.len()
649 )));
650 }
651 out.push_str(&encode_literal(¶ms[idx - 1]));
652 i = j;
653 continue;
654 }
655 out.push(b as char);
656 i += 1;
657 }
658 Ok(out)
659}
660
/// Returns at most the first `n` characters of `s`, always cutting on a char boundary.
fn truncate(s: &str, n: usize) -> &str {
    let cut = s
        .char_indices()
        .map(|(byte_idx, _)| byte_idx)
        .nth(n)
        .unwrap_or(s.len());
    &s[..cut]
}
667
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::types::ParamValue;

    #[test]
    fn test_build_startup_has_user_and_protocol_version() {
        let cfg = BackendConfig {
            host: "localhost".into(),
            port: 5432,
            user: "alice".into(),
            password: None,
            database: Some("app".into()),
            application_name: None,
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_secs(5),
            query_timeout: Duration::from_secs(5),
            tls_config: crate::backend::tls::default_client_config(),
        };
        let bytes = build_startup(&cfg);
        // Bytes 4..8 hold the protocol version (3.0).
        assert_eq!(
            u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]),
            196608
        );
        assert!(bytes.windows(5).any(|w| w == b"user\0"));
        assert!(bytes.windows(10).any(|w| w == b"database\0a"));
    }

    #[test]
    fn test_interpolate_params_basic() {
        let params = vec![ParamValue::Int(42), ParamValue::Text("alice".into())];
        let sql = "SELECT * FROM t WHERE id = $1 AND name = $2";
        let out = interpolate_params(sql, &params).unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE id = 42 AND name = 'alice'");
    }

    #[test]
    fn test_interpolate_params_escapes_quotes() {
        let params = vec![ParamValue::Text("o'brien".into())];
        let out = interpolate_params("SELECT * FROM t WHERE name = $1", &params).unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE name = 'o''brien'");
    }

    #[test]
    fn test_interpolate_params_leaves_dollar_in_string_alone() {
        let params = vec![ParamValue::Int(1)];
        let sql = "SELECT '$1' AS lit, $1 AS val";
        let out = interpolate_params(sql, &params).unwrap();
        assert_eq!(out, "SELECT '$1' AS lit, 1 AS val");
    }

    #[test]
    fn test_interpolate_params_out_of_range() {
        let params = vec![ParamValue::Int(1)];
        let err = interpolate_params("SELECT $2", &params).unwrap_err();
        assert!(matches!(err, BackendError::Protocol(_)));
    }

    #[test]
    fn test_parse_row_description_shape() {
        let mut p = BytesMut::new();
        p.put_u16(1); // one field
        p.extend_from_slice(b"x");
        p.put_u8(0);
        p.put_u32(0); // table OID
        p.put_u16(0); // column attnum
        p.put_u32(23); // type OID (int4)
        p.put_i16(4); // type length
        p.put_i32(-1); // type modifier
        p.put_u16(0); // format code
        let cols = parse_row_description(&p);
        assert_eq!(cols.len(), 1);
        assert_eq!(cols[0].name, "x");
        assert_eq!(cols[0].type_oid, 23);
    }

    #[test]
    fn test_parse_data_row_with_null() {
        let mut p = BytesMut::new();
        p.put_u16(2);
        p.put_i32(1);
        p.extend_from_slice(b"a");
        p.put_i32(-1);
        let row = parse_data_row(&p, 2).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0], TextValue::Text("a".into()));
        assert_eq!(row[1], TextValue::Null);
    }

    #[test]
    fn test_error_message_extracts_m_field() {
        let mut p = Vec::new();
        p.push(b'S');
        p.extend_from_slice(b"ERROR\0");
        p.push(b'C');
        p.extend_from_slice(b"28P01\0");
        p.push(b'M');
        p.extend_from_slice(b"password authentication failed\0");
        p.push(0);
        assert_eq!(error_message(&p), "password authentication failed");
    }

    #[test]
    fn test_error_message_without_m_field() {
        assert_eq!(error_message(b"SERROR\0\0"), "<no message>");
        assert_eq!(error_message(b""), "<no message>");
    }

    #[test]
    fn test_parse_parameter_status() {
        let mut p = Vec::new();
        p.extend_from_slice(b"client_encoding\0");
        p.extend_from_slice(b"UTF8\0");
        let (k, v) = parse_parameter_status(&p).unwrap();
        assert_eq!(k, "client_encoding");
        assert_eq!(v, "UTF8");
    }

    #[test]
    fn test_parse_cstring_without_terminator() {
        assert_eq!(parse_cstring(b"SELECT 1"), "SELECT 1");
        assert_eq!(parse_cstring(b"SELECT 1\0junk"), "SELECT 1");
    }

    #[test]
    fn test_parse_sasl_mechanisms() {
        let mut p = Vec::new();
        p.extend_from_slice(b"SCRAM-SHA-256\0");
        p.extend_from_slice(b"SCRAM-SHA-256-PLUS\0");
        p.push(0);
        let m = parse_sasl_mechanisms(&p);
        assert_eq!(m.len(), 2);
        assert_eq!(m[0], "SCRAM-SHA-256");
        assert_eq!(m[1], "SCRAM-SHA-256-PLUS");
    }

    #[test]
    fn test_generate_nonce_is_url_safe() {
        let n = generate_nonce();
        assert!(n.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
        assert!(n.len() >= 18);
    }

    #[test]
    fn test_truncate_respects_char_boundaries() {
        assert_eq!(truncate("héllo", 2), "hé");
        assert_eq!(truncate("hi", 10), "hi");
    }

    #[test]
    fn test_query_result_rows_affected() {
        let r = QueryResult {
            columns: Vec::new(),
            rows: Vec::new(),
            command_tag: "INSERT 0 5".into(),
        };
        assert_eq!(r.rows_affected(), Some(5));
        let r = QueryResult {
            columns: Vec::new(),
            rows: Vec::new(),
            command_tag: "SET".into(),
        };
        assert_eq!(r.rows_affected(), None);
    }
}