use super::auth::{md5_password_response, Scram};
use super::error::{BackendError, BackendResult};
use super::stream::Stream;
use super::tls::{negotiate, TlsMode};
use super::types::{encode_literal, ParamValue, TextValue};
use crate::protocol::{Message, MessageType, ProtocolCodec};
use bytes::{Buf, BufMut, BytesMut};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
/// Settings used by [`BackendClient::connect`] to open and authenticate a connection
/// to a PostgreSQL-protocol server.
#[derive(Debug, Clone)]
pub struct BackendConfig {
    /// Host name or address to connect to; also passed to TLS negotiation.
    pub host: String,
    /// TCP port of the server.
    pub port: u16,
    /// Role name sent as `user` in the startup message.
    pub user: String,
    /// Password for cleartext, MD5 or SCRAM-SHA-256 authentication; those methods
    /// fail with an authentication error when this is `None`.
    pub password: Option<String>,
    /// Database sent in the startup message; the parameter is omitted when `None`.
    pub database: Option<String>,
    /// `application_name` sent in the startup message; `heliosdb-proxy` when `None`.
    pub application_name: Option<String>,
    /// TLS policy passed to TLS negotiation.
    pub tls_mode: TlsMode,
    /// Upper bound on the whole TCP connect + TLS + startup/authentication handshake.
    pub connect_timeout: Duration,
    /// Timeout intended for individual queries on the connection.
    pub query_timeout: Duration,
    /// rustls client configuration passed to TLS negotiation.
    pub tls_config: Arc<rustls::ClientConfig>,
}
impl BackendConfig {
    /// Returns the `host:port` string used to open the TCP connection and to label
    /// connect-timeout errors.
    pub fn address(&self) -> String {
        format!("{host}:{port}", host = self.host, port = self.port)
    }
}
pub struct BackendClient {
stream: Stream,
pub server_parameters: std::collections::HashMap<String, String>,
pub backend_pid: Option<u32>,
pub backend_secret: Option<u32>,
}
impl BackendClient {
pub async fn connect(cfg: &BackendConfig) -> BackendResult<Self> {
tokio::time::timeout(cfg.connect_timeout, Self::connect_inner(cfg))
.await
.map_err(|_| BackendError::Io(std::io::Error::new(
std::io::ErrorKind::TimedOut,
format!("connect to {} exceeded {:?}", cfg.address(), cfg.connect_timeout),
)))?
}
async fn connect_inner(cfg: &BackendConfig) -> BackendResult<Self> {
let tcp = TcpStream::connect(cfg.address()).await?;
let mut stream =
negotiate(tcp, cfg.tls_mode, cfg.tls_config.clone(), &cfg.host).await?;
let startup = build_startup(cfg);
stream.write_all(&startup).await?;
let mut server_parameters = std::collections::HashMap::new();
let mut backend_pid = None;
let mut backend_secret = None;
let mut buffer = BytesMut::with_capacity(4096);
let codec = ProtocolCodec::new();
let mut scram_state: Option<Scram> = None;
loop {
let msg = read_one(&mut stream, &mut buffer, &codec).await?;
match msg.msg_type {
MessageType::AuthRequest => {
handle_auth(
&mut stream,
&msg,
cfg,
&mut scram_state,
)
.await?;
}
MessageType::ParameterStatus => {
if let Some((k, v)) = parse_parameter_status(&msg.payload) {
server_parameters.insert(k, v);
}
}
MessageType::BackendKeyData => {
if msg.payload.len() >= 8 {
backend_pid = Some(u32::from_be_bytes(
msg.payload[0..4].try_into().unwrap(),
));
backend_secret = Some(u32::from_be_bytes(
msg.payload[4..8].try_into().unwrap(),
));
}
}
MessageType::ReadyForQuery => {
return Ok(Self {
stream,
server_parameters,
backend_pid,
backend_secret,
});
}
MessageType::ErrorResponse => {
return Err(BackendError::BackendError(error_message(&msg.payload)));
}
MessageType::NoticeResponse => {
}
other => {
return Err(BackendError::Protocol(format!(
"unexpected message during startup: {:?}",
other
)));
}
}
}
}
pub async fn simple_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
self.run_query(sql).await
}
pub async fn query_with_params(
&mut self,
sql: &str,
params: &[ParamValue],
) -> BackendResult<QueryResult> {
let substituted = interpolate_params(sql, params)?;
self.run_query(&substituted).await
}
pub async fn query_scalar(&mut self, sql: &str) -> BackendResult<TextValue> {
let res = self.simple_query(sql).await?;
if res.rows.len() != 1 {
return Err(BackendError::Protocol(format!(
"expected 1 row, got {}",
res.rows.len()
)));
}
if res.columns.len() != 1 {
return Err(BackendError::Protocol(format!(
"expected 1 column, got {}",
res.columns.len()
)));
}
Ok(res.rows.into_iter().next().unwrap().into_iter().next().unwrap())
}
pub async fn execute(&mut self, sql: &str) -> BackendResult<String> {
let res = self.simple_query(sql).await?;
Ok(res.command_tag)
}
async fn run_query(&mut self, sql: &str) -> BackendResult<QueryResult> {
let t = self.stream_query_timeout();
tokio::time::timeout(t, Self::run_query_inner(&mut self.stream, sql))
.await
.map_err(|_| BackendError::Io(std::io::Error::new(
std::io::ErrorKind::TimedOut,
format!("query exceeded {:?}: {}", t, truncate(sql, 64)),
)))?
}
fn stream_query_timeout(&self) -> Duration {
Duration::from_secs(30)
}
async fn run_query_inner(stream: &mut Stream, sql: &str) -> BackendResult<QueryResult> {
let mut payload = BytesMut::with_capacity(sql.len() + 1);
payload.extend_from_slice(sql.as_bytes());
payload.put_u8(0);
let frame = Message::new(MessageType::Query, payload).encode();
stream.write_all(&frame).await?;
let mut buffer = BytesMut::with_capacity(8192);
let codec = ProtocolCodec::new();
let mut columns: Vec<ColumnMeta> = Vec::new();
let mut rows: Vec<Vec<TextValue>> = Vec::new();
let mut command_tag = String::new();
let mut last_error: Option<String> = None;
loop {
let msg = read_one(stream, &mut buffer, &codec).await?;
match msg.msg_type {
MessageType::RowDescription => {
columns = parse_row_description(&msg.payload);
}
MessageType::DataRow => {
let row = parse_data_row(&msg.payload, columns.len())?;
rows.push(row);
}
MessageType::CommandComplete | MessageType::Close => {
command_tag = parse_cstring(&msg.payload);
}
MessageType::EmptyQueryResponse => {
command_tag = String::new();
}
MessageType::ErrorResponse => {
last_error = Some(error_message(&msg.payload));
}
MessageType::NoticeResponse => {
tracing::debug!(notice = %error_message(&msg.payload), "backend notice");
}
MessageType::ReadyForQuery => {
if let Some(e) = last_error {
return Err(BackendError::BackendError(e));
}
return Ok(QueryResult {
columns,
rows,
command_tag,
});
}
MessageType::ParameterStatus => {
}
_other => {
}
}
}
}
pub async fn close(mut self) {
let term = Message::new(MessageType::Terminate, BytesMut::new()).encode();
let _ = self.stream.write_all(&term).await;
let _ = self.stream.shutdown().await;
}
pub fn is_tls(&self) -> bool {
self.stream.is_tls()
}
}
/// Metadata for one result column, decoded from a `RowDescription` message.
#[derive(Debug, Clone)]
pub struct ColumnMeta {
    /// Column name as reported by the server (decoded lossily as UTF-8).
    pub name: String,
    /// Type OID reported for the column in `RowDescription`.
    pub type_oid: u32,
}
/// Result of one simple-query round trip.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Columns from the most recent `RowDescription`; empty if none was received.
    pub columns: Vec<ColumnMeta>,
    /// Row values in text form, with `TextValue::Null` for SQL NULL. For a
    /// multi-statement query string, rows from all statements are accumulated here.
    pub rows: Vec<Vec<TextValue>>,
    /// Command tag of the last completed statement (e.g. `INSERT 0 5`); empty for
    /// an empty query.
    pub command_tag: String,
}
impl QueryResult {
    /// Interprets the last word of the command tag as a row count.
    ///
    /// `"INSERT 0 5"` yields `Some(5)`. Tags whose last word is not a number
    /// (e.g. `"SET"`) and empty tags yield `None`.
    pub fn rows_affected(&self) -> Option<u64> {
        let last_word = self.command_tag.split_whitespace().next_back()?;
        last_word.parse::<u64>().ok()
    }
}
/// Encodes the protocol-3.0 `StartupMessage` for `cfg`.
///
/// Layout: a big-endian `u32` total length (counting itself), the protocol
/// version, NUL-terminated key/value pairs (`user`, optional `database`,
/// `application_name`, `client_encoding=UTF8`), and a final NUL byte.
fn build_startup(cfg: &BackendConfig) -> Vec<u8> {
    // Major version 3, minor version 0 (== 196608).
    const PROTOCOL_VERSION_3_0: u32 = 3 << 16;

    let app_name = cfg
        .application_name
        .as_deref()
        .unwrap_or("heliosdb-proxy");
    let mut startup_params: Vec<(&str, &str)> = Vec::with_capacity(4);
    startup_params.push(("user", cfg.user.as_str()));
    if let Some(db) = cfg.database.as_deref() {
        startup_params.push(("database", db));
    }
    startup_params.push(("application_name", app_name));
    startup_params.push(("client_encoding", "UTF8"));

    let mut body = BytesMut::with_capacity(128);
    body.put_u32(PROTOCOL_VERSION_3_0);
    for (key, value) in startup_params {
        put_cstring(&mut body, key);
        put_cstring(&mut body, value);
    }
    body.put_u8(0);

    let total_len = body.len() + 4;
    let mut message = Vec::with_capacity(total_len);
    message.extend_from_slice(&(total_len as u32).to_be_bytes());
    message.extend_from_slice(&body);
    message
}
/// Appends `s` to `buf` followed by a terminating NUL byte.
fn put_cstring(buf: &mut BytesMut, s: &str) {
    buf.put_slice(s.as_bytes());
    buf.put_u8(b'\0');
}
/// Decodes the bytes before the first NUL (or the whole slice if there is none)
/// as UTF-8, replacing invalid sequences with U+FFFD.
fn parse_cstring(payload: &[u8]) -> String {
    let text = payload.split(|&b| b == 0).next().unwrap_or(payload);
    String::from_utf8_lossy(text).into_owned()
}
/// Decodes a `ParameterStatus` body (`name\0value\0`) into `(name, value)`.
///
/// Returns `None` when the name is not NUL-terminated. The value runs to the next
/// NUL or to the end of the payload. Both strings are decoded lossily as UTF-8.
fn parse_parameter_status(payload: &[u8]) -> Option<(String, String)> {
    let mut halves = payload.splitn(2, |&b| b == 0);
    let name = halves.next()?;
    let remainder = halves.next()?;
    let value = remainder.split(|&b| b == 0).next().unwrap_or(remainder);
    Some((
        String::from_utf8_lossy(name).into_owned(),
        String::from_utf8_lossy(value).into_owned(),
    ))
}
/// Decodes a `RowDescription` body into column names and type OIDs.
///
/// Layout: a big-endian `u16` column count, then for each column a NUL-terminated
/// name followed by an 18-byte trailer. Decoding stops early, without error, at
/// the first column whose name or trailer is truncated; a payload shorter than the
/// count yields an empty list.
fn parse_row_description(payload: &[u8]) -> Vec<ColumnMeta> {
    // table oid (4) + column number (2) + type oid (4) + type length (2)
    // + type modifier (4) + format code (2)
    const TRAILER_LEN: usize = 18;
    // Offset of the type OID inside the trailer.
    const TYPE_OID_AT: usize = 6;

    if payload.len() < 2 {
        return Vec::new();
    }
    let declared = u16::from_be_bytes([payload[0], payload[1]]) as usize;
    let mut rest = &payload[2..];
    let mut columns = Vec::with_capacity(declared);
    while columns.len() < declared {
        let nul = match rest.iter().position(|&b| b == 0) {
            Some(pos) => pos,
            None => break,
        };
        let name = String::from_utf8_lossy(&rest[..nul]).into_owned();
        rest = &rest[nul + 1..];
        if rest.len() < TRAILER_LEN {
            break;
        }
        let type_oid = u32::from_be_bytes([
            rest[TYPE_OID_AT],
            rest[TYPE_OID_AT + 1],
            rest[TYPE_OID_AT + 2],
            rest[TYPE_OID_AT + 3],
        ]);
        rest = &rest[TRAILER_LEN..];
        columns.push(ColumnMeta { name, type_oid });
    }
    columns
}
fn parse_data_row(payload: &[u8], column_count: usize) -> BackendResult<Vec<TextValue>> {
let mut p = BytesMut::from(payload);
if p.remaining() < 2 {
return Err(BackendError::Protocol("truncated DataRow".into()));
}
let n = p.get_u16() as usize;
let mut out = Vec::with_capacity(n);
for _ in 0..n {
if p.remaining() < 4 {
return Err(BackendError::Protocol("truncated DataRow field".into()));
}
let len = p.get_i32();
if len == -1 {
out.push(TextValue::Null);
} else {
let len = len as usize;
if p.remaining() < len {
return Err(BackendError::Protocol(
"truncated DataRow value".into(),
));
}
let bytes = p.split_to(len);
out.push(TextValue::Text(
String::from_utf8_lossy(&bytes).into_owned(),
));
}
}
let _ = column_count;
Ok(out)
}
/// Extracts the human-readable message (`M` field) from an `ErrorResponse` or
/// `NoticeResponse` body.
///
/// The body is a sequence of `<type byte><NUL-terminated value>` fields ended by
/// a zero type byte. If several `M` fields occur, the last one wins. Returns
/// `"<no message>"` when there is none.
fn error_message(payload: &[u8]) -> String {
    let mut message = None;
    let mut rest = payload;
    while let Some((&field_type, tail)) = rest.split_first() {
        if field_type == 0 {
            break;
        }
        let value_len = tail.iter().position(|&b| b == 0).unwrap_or(tail.len());
        if field_type == b'M' {
            message = Some(String::from_utf8_lossy(&tail[..value_len]).into_owned());
        }
        // Step over the value and its NUL; an unterminated value ends the scan.
        rest = tail.get(value_len + 1..).unwrap_or(&[]);
    }
    message.unwrap_or_else(|| "<no message>".to_string())
}
/// Reads from `stream` until `codec` can decode one complete message from `buffer`.
///
/// Bytes are appended to `buffer`. Presumably the codec consumes only the bytes of
/// the message it returns, so the remainder stays buffered for the next call —
/// TODO confirm against `ProtocolCodec::decode_message`.
///
/// # Errors
/// `BackendError::Protocol` if the codec rejects the buffered bytes,
/// `BackendError::Closed` if the peer closes the connection before a full
/// message arrives, or the I/O error from a failed read.
async fn read_one(
    stream: &mut Stream,
    buffer: &mut BytesMut,
    codec: &ProtocolCodec,
) -> BackendResult<Message> {
    loop {
        if let Some(msg) = codec
            .decode_message(buffer)
            .map_err(|e| BackendError::Protocol(e.to_string()))?
        {
            return Ok(msg);
        }
        // Read straight into the spare capacity of `buffer` instead of allocating
        // and copying through a fresh scratch Vec on every iteration.
        buffer.reserve(4096);
        if stream.read_buf(buffer).await? == 0 {
            return Err(BackendError::Closed);
        }
    }
}
/// Answers one `AuthRequest` message received during startup.
///
/// Handled request codes:
/// - 0: authentication OK; nothing is sent.
/// - 3: cleartext password.
/// - 5: MD5, using the 4-byte salt at payload bytes 4..8.
/// - 10: SASL start; only `SCRAM-SHA-256` is accepted.
/// - 11: SASL continue.
/// - 12: SASL final, verified with the state stored in `scram_state`.
///
/// SCRAM progress is carried between calls in `scram_state`.
///
/// # Errors
/// - `Protocol` for a truncated payload.
/// - `Auth` when a required password is missing, no supported SASL mechanism is
///   offered, SASL messages arrive out of order, or the code is unsupported.
/// - Any error from `Scram` or from writing the response.
async fn handle_auth(
    stream: &mut Stream,
    msg: &Message,
    cfg: &BackendConfig,
    scram_state: &mut Option<Scram>,
) -> BackendResult<()> {
    const AUTH_OK: u32 = 0;
    const AUTH_CLEARTEXT_PASSWORD: u32 = 3;
    const AUTH_MD5_PASSWORD: u32 = 5;
    const AUTH_SASL: u32 = 10;
    const AUTH_SASL_CONTINUE: u32 = 11;
    const AUTH_SASL_FINAL: u32 = 12;

    let payload: &[u8] = &msg.payload;
    if payload.len() < 4 {
        return Err(BackendError::Protocol(
            "AuthRequest payload < 4 bytes".into(),
        ));
    }
    let (code_bytes, body) = payload.split_at(4);
    let code = u32::from_be_bytes([code_bytes[0], code_bytes[1], code_bytes[2], code_bytes[3]]);

    match code {
        AUTH_OK => Ok(()),
        AUTH_MD5_PASSWORD => {
            if body.len() < 4 {
                return Err(BackendError::Protocol("AuthenticationMD5 truncated".into()));
            }
            let salt: [u8; 4] = [body[0], body[1], body[2], body[3]];
            let password = cfg.password.as_deref().ok_or_else(|| {
                BackendError::Auth("server requested MD5 but no password configured".into())
            })?;
            let response = md5_password_response(&cfg.user, password, &salt);
            write_password_message(stream, &response).await
        }
        AUTH_CLEARTEXT_PASSWORD => {
            let password = cfg.password.as_deref().ok_or_else(|| {
                BackendError::Auth("server requested password but none configured".into())
            })?;
            let mut response = password.as_bytes().to_vec();
            response.push(0);
            write_password_message(stream, &response).await
        }
        AUTH_SASL => {
            let mechanisms = parse_sasl_mechanisms(body);
            let offers_scram = mechanisms.iter().any(|m| m == "SCRAM-SHA-256");
            if !offers_scram {
                return Err(BackendError::Auth(format!(
                    "no supported SASL mechanism; server offered {:?}",
                    mechanisms
                )));
            }
            let (scram, first) = Scram::client_first(generate_nonce());
            *scram_state = Some(scram);
            write_password_message(stream, &first.0).await
        }
        AUTH_SASL_CONTINUE => {
            let scram = scram_state.as_mut().ok_or_else(|| {
                BackendError::Auth("SASLContinue before SASL start".into())
            })?;
            let password = cfg.password.as_deref().ok_or_else(|| {
                BackendError::Auth("SCRAM requires a password".into())
            })?;
            let client_final = scram.client_final(body, password)?;
            write_password_message(stream, &client_final.0).await
        }
        AUTH_SASL_FINAL => {
            let scram = scram_state.as_ref().ok_or_else(|| {
                BackendError::Auth("SASLFinal before SASL start".into())
            })?;
            scram.verify_server(body)
        }
        unsupported => Err(BackendError::Auth(format!(
            "unsupported authentication request code: {}",
            unsupported
        ))),
    }
}
async fn write_password_message(
stream: &mut Stream,
payload: &[u8],
) -> BackendResult<()> {
let mut buf = BytesMut::with_capacity(payload.len() + 5);
buf.put_u8(b'p');
buf.put_u32((payload.len() + 4) as u32);
buf.extend_from_slice(payload);
stream.write_all(&buf).await?;
Ok(())
}
/// Decodes the NUL-separated mechanism list of an `AuthenticationSASL` request.
///
/// The list ends at the first empty entry (the terminating NUL) or at the end of
/// the payload. Names are decoded lossily as UTF-8.
fn parse_sasl_mechanisms(payload: &[u8]) -> Vec<String> {
    payload
        .split(|&b| b == 0)
        .take_while(|name| !name.is_empty())
        .map(|name| String::from_utf8_lossy(name).into_owned())
        .collect()
}
/// Returns a fresh SCRAM client nonce.
///
/// The nonce is 18 random bytes from `rand::thread_rng`, base64-encoded with the
/// URL-safe alphabet and no padding. That gives 24 characters, none of which is
/// `,`.
fn generate_nonce() -> String {
    use base64::Engine as _;
    use rand::RngCore;
    let mut rng = rand::thread_rng();
    let mut raw = [0u8; 18];
    rng.fill_bytes(&mut raw);
    let engine = &base64::engine::general_purpose::URL_SAFE_NO_PAD;
    engine.encode(raw)
}
/// Substitutes `$N` placeholders in `sql` with SQL literals rendered from `params`.
///
/// This is client-side interpolation for the simple-query protocol. Each `$N`
/// (1-based) is replaced by `encode_literal(&params[N - 1])`. Placeholders are
/// recognised only in plain SQL text; these regions are copied verbatim:
/// - single-quoted string literals (`''` escapes; `E'...'` also honours
///   backslash escapes);
/// - double-quoted identifiers;
/// - `--` line comments and nestable `/* ... */` block comments;
/// - dollar-quoted strings (`$$...$$`, `$tag$...$tag$`).
///
/// Skipping comments and dollar-quoted bodies matters for safety as well as
/// correctness. A literal substituted into a `--` comment could end the comment
/// via an embedded newline. A literal substituted into a `$$` body could close it
/// early. Either way, parameter data would become executable SQL.
///
/// Verbatim text is copied as whole `&str` slices, so non-ASCII characters are
/// preserved exactly.
///
/// # Errors
/// `BackendError::Protocol` if a placeholder number does not parse or is outside
/// `1..=params.len()`.
fn interpolate_params(sql: &str, params: &[ParamValue]) -> BackendResult<String> {
    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len());
    // `sql[copied..i]` is verbatim text that has not been appended to `out` yet.
    let mut copied = 0;
    let mut i = 0;
    while i < bytes.len() {
        let next = bytes.get(i + 1).copied();
        i = match bytes[i] {
            b'\'' => {
                // `E'...'` (not part of a longer identifier) enables backslash escapes.
                let backslash_escapes = i >= 1
                    && matches!(bytes[i - 1], b'E' | b'e')
                    && (i == 1 || !is_sql_ident_byte(bytes[i - 2]));
                skip_quoted(bytes, i, b'\'', backslash_escapes)
            }
            b'"' => skip_quoted(bytes, i, b'"', false),
            b'-' if next == Some(b'-') => skip_line_comment(bytes, i),
            b'/' if next == Some(b'*') => skip_block_comment(bytes, i),
            b'$' if matches!(next, Some(d) if d.is_ascii_digit()) => {
                let digits_end = bytes[i + 1..]
                    .iter()
                    .position(|b| !b.is_ascii_digit())
                    .map_or(bytes.len(), |p| i + 1 + p);
                let idx: usize = sql[i + 1..digits_end].parse().map_err(|_| {
                    BackendError::Protocol(format!(
                        "invalid parameter reference at byte {}",
                        i
                    ))
                })?;
                if idx == 0 || idx > params.len() {
                    return Err(BackendError::Protocol(format!(
                        "parameter ${} out of range (have {})",
                        idx,
                        params.len()
                    )));
                }
                out.push_str(&sql[copied..i]);
                out.push_str(&encode_literal(&params[idx - 1]));
                copied = digits_end;
                digits_end
            }
            // A `$` that continues an identifier (e.g. `foo$bar`) cannot open a dollar quote.
            b'$' if i == 0 || !is_sql_ident_byte(bytes[i - 1]) => {
                match dollar_quote_delimiter_len(bytes, i) {
                    Some(delim_len) => skip_dollar_quoted(bytes, i, delim_len),
                    None => i + 1,
                }
            }
            _ => i + 1,
        };
    }
    out.push_str(&sql[copied..]);
    Ok(out)
}

/// Returns whether `b` can continue an SQL identifier (letters, digits, `_`, `$`,
/// or any byte of a non-ASCII character).
fn is_sql_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80
}

/// Returns the index just past the quoted region opened at `open`, or
/// `bytes.len()` if it is unterminated. A doubled `quote` is an escaped quote;
/// with `backslash_escapes`, a backslash escapes the following byte.
fn skip_quoted(bytes: &[u8], open: usize, quote: u8, backslash_escapes: bool) -> usize {
    let mut i = open + 1;
    while i < bytes.len() {
        match bytes[i] {
            b'\\' if backslash_escapes => i += 2,
            b if b == quote => {
                if bytes.get(i + 1) == Some(&quote) {
                    i += 2;
                } else {
                    return i + 1;
                }
            }
            _ => i += 1,
        }
    }
    bytes.len()
}

/// Returns the index of the line terminator that ends the `--` comment starting
/// at `start`, or `bytes.len()` if the comment runs to the end.
fn skip_line_comment(bytes: &[u8], start: usize) -> usize {
    bytes[start..]
        .iter()
        .position(|&b| b == b'\n' || b == b'\r')
        .map_or(bytes.len(), |p| start + p)
}

/// Returns the index just past the `/* ... */` comment starting at `start`,
/// honouring nesting, or `bytes.len()` if it is unterminated.
fn skip_block_comment(bytes: &[u8], start: usize) -> usize {
    let mut depth = 0usize;
    let mut i = start;
    while i + 1 < bytes.len() {
        match (bytes[i], bytes[i + 1]) {
            (b'/', b'*') => {
                depth += 1;
                i += 2;
            }
            (b'*', b'/') => {
                // depth >= 1 here: the first pair at `start` is always `/*`.
                depth -= 1;
                i += 2;
                if depth == 0 {
                    return i;
                }
            }
            _ => i += 1,
        }
    }
    bytes.len()
}

/// If a dollar-quote delimiter (`$$` or `$tag$`) starts at `start`, returns its
/// length in bytes. Callers handle `$<digit>` (a placeholder) before calling this.
fn dollar_quote_delimiter_len(bytes: &[u8], start: usize) -> Option<usize> {
    let tag_len = bytes[start + 1..]
        .iter()
        .take_while(|&&b| b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80)
        .count();
    let closing = start + 1 + tag_len;
    (bytes.get(closing) == Some(&b'$')).then(|| closing + 1 - start)
}

/// Returns the index just past the dollar-quoted string whose `delim_len`-byte
/// opening delimiter starts at `start`, or `bytes.len()` if it is never closed.
fn skip_dollar_quoted(bytes: &[u8], start: usize, delim_len: usize) -> usize {
    let body_start = start + delim_len;
    let delim = &bytes[start..body_start];
    bytes[body_start..]
        .windows(delim_len)
        .position(|w| w == delim)
        .map_or(bytes.len(), |p| body_start + p + delim_len)
}
/// Returns the prefix of `s` holding at most `n` characters, always cut on a
/// character boundary.
fn truncate(s: &str, n: usize) -> &str {
    s.char_indices().nth(n).map_or(s, |(cut, _)| &s[..cut])
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::types::ParamValue;

    // Startup message: the length prefix is followed by protocol 3.0, and the
    // user/database parameters appear as NUL-terminated pairs.
    #[test]
    fn test_build_startup_has_user_and_protocol_version() {
        let cfg = BackendConfig {
            host: "localhost".into(),
            port: 5432,
            user: "alice".into(),
            password: None,
            database: Some("app".into()),
            application_name: None,
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_secs(5),
            query_timeout: Duration::from_secs(5),
            tls_config: crate::backend::tls::default_client_config(),
        };
        let bytes = build_startup(&cfg);
        assert_eq!(
            u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]),
            196608
        );
        assert!(bytes
            .windows(5)
            .any(|w| w == b"user\0"));
        assert!(bytes
            .windows(10)
            .any(|w| w == b"database\0a"));
    }

    // Placeholders are replaced by encoded literals in order.
    #[test]
    fn test_interpolate_params_basic() {
        let params = vec![
            ParamValue::Int(42),
            ParamValue::Text("alice".into()),
        ];
        let sql = "SELECT * FROM t WHERE id = $1 AND name = $2";
        let out = interpolate_params(sql, &params).unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE id = 42 AND name = 'alice'");
    }

    // Text parameters have embedded single quotes doubled.
    #[test]
    fn test_interpolate_params_escapes_quotes() {
        let params = vec![ParamValue::Text("o'brien".into())];
        let out =
            interpolate_params("SELECT * FROM t WHERE name = $1", &params).unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE name = 'o''brien'");
    }

    // `$1` inside a string literal is left untouched.
    #[test]
    fn test_interpolate_params_leaves_dollar_in_string_alone() {
        let params = vec![ParamValue::Int(1)];
        let sql = "SELECT '$1' AS lit, $1 AS val";
        let out = interpolate_params(sql, &params).unwrap();
        assert_eq!(out, "SELECT '$1' AS lit, 1 AS val");
    }

    // Referencing a parameter that was not supplied is a protocol error.
    #[test]
    fn test_interpolate_params_out_of_range() {
        let params = vec![ParamValue::Int(1)];
        let err = interpolate_params("SELECT $2", &params).unwrap_err();
        assert!(matches!(err, BackendError::Protocol(_)));
    }

    // One column `x` of type OID 23 (the fixed 18-byte trailer after the name).
    #[test]
    fn test_parse_row_description_shape() {
        let mut p = BytesMut::new();
        p.put_u16(1);
        p.extend_from_slice(b"x");
        p.put_u8(0);
        p.put_u32(0); p.put_u16(0); p.put_u32(23); p.put_i16(4);
        p.put_i32(-1);
        p.put_u16(0);
        let cols = parse_row_description(&p);
        assert_eq!(cols.len(), 1);
        assert_eq!(cols[0].name, "x");
        assert_eq!(cols[0].type_oid, 23);
    }

    // A field length of -1 decodes as SQL NULL.
    #[test]
    fn test_parse_data_row_with_null() {
        let mut p = BytesMut::new();
        p.put_u16(2);
        p.put_i32(1);
        p.extend_from_slice(b"a");
        p.put_i32(-1);
        let row = parse_data_row(&p, 2).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0], TextValue::Text("a".into()));
        assert_eq!(row[1], TextValue::Null);
    }

    // Only the `M` (message) field is extracted from an ErrorResponse.
    #[test]
    fn test_error_message_extracts_m_field() {
        let mut p = Vec::new();
        p.push(b'S');
        p.extend_from_slice(b"ERROR\0");
        p.push(b'C');
        p.extend_from_slice(b"28P01\0");
        p.push(b'M');
        p.extend_from_slice(b"password authentication failed\0");
        p.push(0);
        assert_eq!(error_message(&p), "password authentication failed");
    }

    #[test]
    fn test_parse_parameter_status() {
        let mut p = Vec::new();
        p.extend_from_slice(b"client_encoding\0");
        p.extend_from_slice(b"UTF8\0");
        let (k, v) = parse_parameter_status(&p).unwrap();
        assert_eq!(k, "client_encoding");
        assert_eq!(v, "UTF8");
    }

    // The mechanism list stops at the terminating empty entry.
    #[test]
    fn test_parse_sasl_mechanisms() {
        let mut p = Vec::new();
        p.extend_from_slice(b"SCRAM-SHA-256\0");
        p.extend_from_slice(b"SCRAM-SHA-256-PLUS\0");
        p.push(0);
        let m = parse_sasl_mechanisms(&p);
        assert_eq!(m.len(), 2);
        assert_eq!(m[0], "SCRAM-SHA-256");
        assert_eq!(m[1], "SCRAM-SHA-256-PLUS");
    }

    // SCRAM nonces must not contain ','; URL-safe base64 guarantees that.
    #[test]
    fn test_generate_nonce_is_url_safe() {
        let n = generate_nonce();
        assert!(n.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
        assert!(n.len() >= 18);
    }

    // The row count comes from the last word of the command tag.
    #[test]
    fn test_query_result_rows_affected() {
        let r = QueryResult {
            columns: Vec::new(),
            rows: Vec::new(),
            command_tag: "INSERT 0 5".into(),
        };
        assert_eq!(r.rows_affected(), Some(5));
        let r = QueryResult {
            columns: Vec::new(),
            rows: Vec::new(),
            command_tag: "SET".into(),
        };
        assert_eq!(r.rows_affected(), None);
    }
}