use std::io::{Read, Write};
use bytes::BytesMut;
use tracing::{debug, info, trace, warn};
use crate::protocol::message::{backend::Message, frontend};
use super::auth::{self, AuthState};
use super::error::{Error, Result};
/// Message budget used by `RawConnection::consume_error` when draining the
/// remainder of a failed request up to `ReadyForQuery`. If the budget runs out
/// first, the connection is marked desynchronized.
pub const POST_ERROR_DRAIN_CAP: usize = 1 << 10;
/// Synchronous client connection speaking the PostgreSQL-style wire protocol
/// over any `Read + Write` stream.
///
/// Outgoing frontend messages are staged in `write_buf` and written on flush.
/// Incoming bytes accumulate in `read_buf` until a complete backend message
/// can be parsed.
pub struct RawConnection<S> {
    /// Underlying transport.
    stream: S,
    /// Received bytes not yet parsed into a backend message.
    read_buf: BytesMut,
    /// Encoded frontend messages not yet written to `stream`.
    write_buf: BytesMut,
    /// Backend process id from `BackendKeyData` (0 until received).
    process_id: i32,
    /// Cancellation secret from `BackendKeyData` (0 until received).
    secret_key: i32,
    /// `ParameterStatus` values reported by the server, keyed by name.
    server_params: std::collections::HashMap<String, String>,
    /// Set once the connection's position in the message stream is unknown.
    /// Such a connection must not be reused.
    desynchronized: bool,
}

/// Hand-written instead of derived so that `{:?}` does not dump the raw
/// protocol buffers (which can transiently hold result rows or password
/// messages) and does not print the cancellation secret.
impl<S: std::fmt::Debug> std::fmt::Debug for RawConnection<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RawConnection")
            .field("stream", &self.stream)
            .field("read_buf_len", &self.read_buf.len())
            .field("write_buf_len", &self.write_buf.len())
            .field("process_id", &self.process_id)
            .field("secret_key", &"<redacted>")
            .field("server_params", &self.server_params)
            .field("desynchronized", &self.desynchronized)
            .finish()
    }
}
impl<S> RawConnection<S>
where
    S: Read + Write,
{
    /// Format code requesting Hyper's binary encoding for result columns.
    const HYPER_BINARY_FORMAT: i16 = 2;
    /// Format code declaring PostgreSQL binary encoding for bound parameters.
    const PG_BINARY_FORMAT: i16 = 1;
    /// Number of bytes `read_message` asks the stream for per `read` call.
    const READ_CHUNK: usize = 64 * 1024;

    /// Wraps an already-connected `stream`.
    ///
    /// Nothing is sent or received here. Call [`Self::startup`] to run the
    /// startup/authentication handshake before issuing requests.
    pub fn new(stream: S) -> Self {
        RawConnection {
            stream,
            read_buf: BytesMut::with_capacity(64 * 1024),
            write_buf: BytesMut::with_capacity(64 * 1024),
            process_id: 0,
            secret_key: 0,
            server_params: std::collections::HashMap::new(),
            desynchronized: false,
        }
    }

    /// Returns `false` once the connection has been marked desynchronized.
    /// After that point it must not be reused.
    pub fn is_healthy(&self) -> bool {
        !self.desynchronized
    }

    /// Fails fast with a `Connection` error if the connection is desynchronized.
    ///
    /// The query-issuing methods call this first, so a broken connection is
    /// never handed responses that belong to an earlier request.
    pub(crate) fn ensure_healthy(&self) -> Result<()> {
        if self.desynchronized {
            return Err(Error::new(
                crate::client::error::ErrorKind::Connection,
                "connection is desynchronized from the server and cannot be reused; \
                 discard it and open a new one",
            ));
        }
        Ok(())
    }

    /// Reserves room for at least `additional` more bytes in the write buffer.
    pub fn reserve_write_buffer(&mut self, additional: usize) {
        self.write_buf.reserve(additional);
    }

    /// Backend process id from `BackendKeyData`, or 0 if none was received.
    pub fn process_id(&self) -> i32 {
        self.process_id
    }

    /// Cancellation secret from `BackendKeyData`, or 0 if none was received.
    pub fn secret_key(&self) -> i32 {
        self.secret_key
    }

    /// Shared access to the underlying stream.
    pub fn stream(&self) -> &S {
        &self.stream
    }

    /// Mutable access to the underlying stream.
    ///
    /// Reading or writing through it bypasses this connection's buffers.
    pub fn stream_mut(&mut self) -> &mut S {
        &mut self.stream
    }

    /// Looks up a `ParameterStatus` value received during startup.
    pub fn parameter_status(&self, name: &str) -> Option<&str> {
        self.server_params
            .get(name)
            .map(std::string::String::as_str)
    }

    /// Performs the startup handshake.
    ///
    /// Sends a StartupMessage built from `params` and answers the server's
    /// authentication request (cleartext, MD5, or SCRAM-SHA-256 over SASL).
    /// It records `BackendKeyData` and `ParameterStatus` values and returns at
    /// the first `ReadyForQuery`. For MD5, the `user` entry of `params` (or
    /// `""` if absent) is used when computing the digest.
    ///
    /// # Errors
    ///
    /// - Authentication error if the server asks for a password and
    ///   `password` is `None`.
    /// - Authentication error if the server offers no SCRAM-SHA-256 mechanism.
    /// - The server's error when it replies with `ErrorResponse`.
    /// - Protocol error for any other unexpected message.
    /// - I/O errors from the stream.
    pub fn startup(&mut self, params: &[(&str, &str)], password: Option<&str>) -> Result<()> {
        frontend::startup_message(params, &mut self.write_buf)?;
        self.flush()?;
        // SCRAM state carried between the SASL initial, continue and final steps.
        let mut auth_state: Option<AuthState> = None;
        loop {
            let msg = self.read_message()?;
            match msg {
                Message::AuthenticationOk => {
                    info!(target: "hyperdb_api", "connection-auth-success");
                }
                Message::AuthenticationCleartextPassword => {
                    debug!(target: "hyperdb_api", method = "cleartext", "connection-auth-method");
                    let password = password.ok_or_else(|| {
                        Error::authentication(
                            "server requested cleartext password but none provided",
                        )
                    })?;
                    frontend::password_message(password, &mut self.write_buf)?;
                    self.flush()?;
                }
                Message::AuthenticationMd5Password(body) => {
                    debug!(target: "hyperdb_api", method = "MD5", "connection-auth-method");
                    let password = password.ok_or_else(|| {
                        Error::authentication("server requested MD5 password but none provided")
                    })?;
                    let user = params
                        .iter()
                        .find(|(k, _)| *k == "user")
                        .map_or("", |(_, v)| *v);
                    let md5_response = auth::compute_md5_password(user, password, &body.salt());
                    frontend::password_message(&md5_response, &mut self.write_buf)?;
                    self.flush()?;
                }
                Message::AuthenticationSasl(body) => {
                    debug!(target: "hyperdb_api", method = "SCRAM-SHA-256", "connection-auth-method");
                    let password = password.ok_or_else(|| {
                        Error::authentication(
                            "server requested SASL authentication but no password provided",
                        )
                    })?;
                    let mechanisms: Vec<&str> = body.mechanisms().collect();
                    if !mechanisms.contains(&"SCRAM-SHA-256") {
                        return Err(Error::authentication(format!(
                            "server offered unsupported SASL mechanisms: {mechanisms:?}"
                        )));
                    }
                    let (state, client_first) = auth::scram_client_first(password)?;
                    auth_state = Some(state);
                    frontend::sasl_initial_response(
                        "SCRAM-SHA-256",
                        &client_first,
                        &mut self.write_buf,
                    )?;
                    self.flush()?;
                }
                Message::AuthenticationSaslContinue(body) => {
                    let state = auth_state.take().ok_or_else(|| {
                        Error::authentication("received SASL continue without initial state")
                    })?;
                    let server_first = body.data();
                    let (new_state, client_final) = auth::scram_client_final(state, server_first)?;
                    auth_state = Some(new_state);
                    frontend::sasl_response(&client_final, &mut self.write_buf)?;
                    self.flush()?;
                }
                Message::AuthenticationSaslFinal(body) => {
                    let state = auth_state.take().ok_or_else(|| {
                        Error::authentication("received SASL final without state")
                    })?;
                    auth::scram_verify_server(state, body.data())?;
                }
                Message::BackendKeyData(data) => {
                    self.process_id = data.process_id();
                    self.secret_key = data.secret_key();
                }
                Message::ParameterStatus(body) => {
                    // Entries whose name or value fail to decode are skipped.
                    if let (Ok(name), Ok(value)) = (body.name(), body.value()) {
                        self.server_params
                            .insert(name.to_string(), value.to_string());
                    }
                }
                Message::ReadyForQuery(_) => {
                    return Ok(());
                }
                Message::ErrorResponse(body) => {
                    return Err(self.consume_error(&body));
                }
                _ => {
                    return Err(Error::protocol("unexpected message during startup"));
                }
            }
        }
    }

    /// Runs `query` with the simple-query protocol.
    ///
    /// Returns every backend message up to and including `ReadyForQuery`.
    ///
    /// # Errors
    ///
    /// Fails fast if the connection is desynchronized. Returns the server's
    /// error (after draining to `ReadyForQuery`) on `ErrorResponse`, and
    /// returns I/O or encoding errors.
    pub fn simple_query(&mut self, query: &str) -> Result<Vec<Message>> {
        self.ensure_healthy()?;
        self.send_query(query)?;
        self.collect_until_ready()
    }

    /// Runs `query` with the extended protocol, requesting Hyper binary results.
    ///
    /// Sends Parse/Bind/Describe(portal)/Execute/Sync on the unnamed statement
    /// and portal. Returns every backend message up to and including
    /// `ReadyForQuery`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::simple_query`].
    pub fn query_binary(&mut self, query: &str) -> Result<Vec<Message>> {
        self.ensure_healthy()?;
        self.encode_binary_query(query)?;
        self.flush()?;
        self.collect_until_ready()
    }

    /// Sends the same request as [`Self::query_binary`] without reading any
    /// response.
    ///
    /// The caller is expected to consume the responses (e.g. via
    /// [`Self::read_message`]) up to `ReadyForQuery`.
    pub fn start_query_binary(&mut self, query: &str) -> Result<()> {
        self.ensure_healthy()?;
        self.encode_binary_query(query)?;
        self.flush()
    }

    /// Sends `query` with the simple-query protocol without reading any
    /// response.
    pub fn start_simple_query(&mut self, query: &str) -> Result<()> {
        self.ensure_healthy()?;
        self.send_query(query)
    }

    /// Binds and executes the prepared statement `statement_name` without
    /// reading any response.
    ///
    /// Every entry of `params` is declared as PostgreSQL-binary encoded
    /// (`None` presumably encodes SQL NULL — depends on `frontend::bind`). All
    /// `column_count` result columns are requested in Hyper binary format.
    pub fn start_execute_prepared(
        &mut self,
        statement_name: &str,
        params: &[Option<&[u8]>],
        column_count: usize,
    ) -> Result<()> {
        self.ensure_healthy()?;
        let param_formats: Vec<i16> = vec![Self::PG_BINARY_FORMAT; params.len()];
        let result_formats: Vec<i16> = vec![Self::HYPER_BINARY_FORMAT; column_count];
        self.encode_with_rollback(|buf| {
            frontend::bind(
                "",
                statement_name,
                &param_formats,
                params,
                &result_formats,
                buf,
            )?;
            frontend::execute("", 0, buf)?;
            frontend::sync(buf);
            Ok(())
        })?;
        self.flush()
    }

    /// Returns the next complete backend message, reading from the stream as
    /// needed.
    ///
    /// # Errors
    ///
    /// - `Error::closed()` when the stream reports EOF.
    /// - The stream's I/O error.
    /// - A parse error for malformed data.
    ///
    /// On an I/O error the already-received bytes stay buffered, so a retry
    /// resumes at the same position.
    pub fn read_message(&mut self) -> Result<Message> {
        loop {
            if let Some(msg) = Message::parse(&mut self.read_buf).map_err(Error::io)? {
                return Ok(msg);
            }
            let filled = self.read_buf.len();
            self.read_buf.resize(filled + Self::READ_CHUNK, 0);
            let read = self.stream.read(&mut self.read_buf[filled..]);
            // Only the bytes actually received may remain. The zero padding
            // would otherwise be parsed as message data by the next call.
            let n = match read {
                Ok(n) => {
                    self.read_buf.truncate(filled + n);
                    n
                }
                Err(e) => {
                    self.read_buf.truncate(filled);
                    if e.kind() == std::io::ErrorKind::Interrupted {
                        // Interrupted reads transfer no data; retry like std's helpers do.
                        continue;
                    }
                    return Err(e.into());
                }
            };
            if n == 0 {
                warn!(target: "hyperdb_api", "connection-closed");
                return Err(Error::closed());
            }
        }
    }

    /// Reads and discards messages until `ReadyForQuery`, with no message
    /// budget.
    ///
    /// A read failure marks the connection desynchronized.
    pub fn drain_until_ready(&mut self) {
        let _ = self.drain_until_ready_bounded(usize::MAX);
    }

    /// Reads and discards at most `max_messages` messages looking for
    /// `ReadyForQuery`.
    ///
    /// Returns `true` if it was seen. On a read error, or when the budget runs
    /// out, it logs a warning, marks the connection desynchronized and returns
    /// `false`.
    pub fn drain_until_ready_bounded(&mut self, max_messages: usize) -> bool {
        for i in 0..max_messages {
            match self.read_message() {
                Ok(Message::ReadyForQuery(_)) => return true,
                Ok(_) => {}
                Err(e) => {
                    warn!(
                        target: "hyperdb_api_core::client",
                        error = %e,
                        messages_read = i,
                        "drain_until_ready: read error mid-drain (likely closed connection); \
                         connection marked desynchronized",
                    );
                    self.desynchronized = true;
                    return false;
                }
            }
        }
        warn!(
            target: "hyperdb_api_core::client",
            max_messages,
            "drain_until_ready_bounded: exhausted budget without seeing ReadyForQuery; \
             connection marked desynchronized and should not be reused",
        );
        self.desynchronized = true;
        false
    }

    /// Converts an `ErrorResponse` into an [`Error`] and resynchronizes.
    ///
    /// It drains up to [`POST_ERROR_DRAIN_CAP`] further messages to reach
    /// `ReadyForQuery`. If that fails, the connection is marked
    /// desynchronized; the parsed error is returned either way.
    pub fn consume_error(
        &mut self,
        body: &crate::protocol::message::backend::ErrorResponseBody,
    ) -> Error {
        let err = parse_error_response(body);
        let _ = self.drain_until_ready_bounded(POST_ERROR_DRAIN_CAP);
        err
    }

    /// Writes all buffered frontend messages to the stream and flushes it.
    ///
    /// # Errors
    ///
    /// Returns the stream's I/O error. Part of the buffer may already be on
    /// the wire at that point, so the connection is marked desynchronized.
    pub fn flush(&mut self) -> Result<()> {
        if self.write_buf.is_empty() {
            return Ok(());
        }
        let sent = match self.stream.write_all(&self.write_buf) {
            Ok(()) => self.stream.flush(),
            Err(e) => Err(e),
        };
        if let Err(e) = sent {
            // Neither resending nor dropping a partially written buffer
            // restores a valid message stream.
            self.desynchronized = true;
            return Err(e.into());
        }
        self.write_buf.clear();
        Ok(())
    }

    /// Sends a Terminate message and flushes it.
    pub fn terminate(&mut self) -> Result<()> {
        frontend::terminate(&mut self.write_buf);
        self.flush()
    }

    /// Direct access to the outgoing buffer.
    ///
    /// Bytes appended here are sent on the next [`Self::flush`].
    pub fn write_buf(&mut self) -> &mut BytesMut {
        &mut self.write_buf
    }

    /// Starts `COPY ... FROM STDIN` in `HYPERBINARY` format.
    ///
    /// See [`Self::start_copy_in_with_format`].
    pub fn start_copy_in(&mut self, table_name: &str, columns: &[&str]) -> Result<()> {
        self.start_copy_in_with_format(table_name, columns, "HYPERBINARY")
    }

    /// Starts `COPY table_name (columns) FROM STDIN WITH (FORMAT format)`.
    ///
    /// Column names are double-quoted, with embedded quotes doubled.
    /// `table_name` and `format` are interpolated verbatim, so callers must
    /// pass trusted, already-quoted values.
    pub fn start_copy_in_with_format(
        &mut self,
        table_name: &str,
        columns: &[&str],
        format: &str,
    ) -> Result<()> {
        let column_list = if columns.is_empty() {
            String::new()
        } else {
            let quoted: Vec<String> = columns
                .iter()
                .map(|c| format!("\"{}\"", c.replace('"', "\"\"")))
                .collect();
            format!(" ({})", quoted.join(", "))
        };
        let query = format!("COPY {table_name}{column_list} FROM STDIN WITH (FORMAT {format})");
        self.start_copy_in_raw(&query)
    }

    /// Sends a caller-built COPY FROM STDIN `query` and waits for
    /// `CopyInResponse`.
    ///
    /// # Errors
    ///
    /// - The server's error on `ErrorResponse`.
    /// - A protocol error if the statement completes (`ReadyForQuery`) without
    ///   entering COPY IN mode. The connection is still in sync in that case.
    pub fn start_copy_in_raw(&mut self, query: &str) -> Result<()> {
        self.ensure_healthy()?;
        self.send_query(query)?;
        loop {
            match self.read_response_message()? {
                Message::CopyInResponse(_) => return Ok(()),
                Message::ErrorResponse(body) => return Err(self.consume_error(&body)),
                // Without this arm a non-COPY-IN statement would make us block
                // forever waiting for messages after the server went idle.
                Message::ReadyForQuery(_) => {
                    return Err(Error::protocol(
                        "server completed the statement without entering COPY IN mode",
                    ));
                }
                _ => {}
            }
        }
    }

    /// Appends a CopyData message to the write buffer.
    ///
    /// Nothing is sent until the next flush.
    pub fn send_copy_data(&mut self, data: &[u8]) -> Result<()> {
        frontend::copy_data(data, &mut self.write_buf);
        Ok(())
    }

    /// Writes any buffered messages, then a CopyData frame for `data`, straight
    /// to the stream.
    ///
    /// This avoids copying `data` into the write buffer. The stream is not
    /// flushed; see [`Self::flush_stream`].
    ///
    /// # Errors
    ///
    /// - A protocol error if the frame length does not fit in `u32`. Nothing
    ///   is written in that case.
    /// - The stream's I/O error, after which the connection is marked
    ///   desynchronized.
    pub fn send_copy_data_direct(&mut self, data: &[u8]) -> Result<()> {
        let msg_len = u32::try_from(4 + data.len())
            .map_err(|_| Error::protocol("CopyData payload exceeds u32::MAX bytes"))?;
        let len_be = msg_len.to_be_bytes();
        let header = [b'd', len_be[0], len_be[1], len_be[2], len_be[3]];
        if let Err(e) =
            Self::write_copy_frame(&mut self.stream, &mut self.write_buf, &header, data)
        {
            self.desynchronized = true;
            return Err(e.into());
        }
        Ok(())
    }

    /// Flushes the underlying stream without touching the write buffer.
    pub fn flush_stream(&mut self) -> Result<()> {
        self.stream.flush()?;
        Ok(())
    }

    /// Ends COPY IN with CopyDone and waits for `ReadyForQuery`.
    ///
    /// Returns the row count parsed from a `COPY <n>` command tag, or 0 if no
    /// such tag was seen.
    pub fn finish_copy(&mut self) -> Result<u64> {
        // Pending CopyData and the CopyDone go out in a single flush.
        frontend::copy_done(&mut self.write_buf);
        self.flush()?;
        let mut row_count = 0u64;
        loop {
            match self.read_response_message()? {
                Message::CommandComplete(body) => {
                    let parsed = body
                        .tag()
                        .ok()
                        .and_then(|tag| tag.strip_prefix("COPY "))
                        .and_then(|count| count.trim().parse().ok());
                    if let Some(count) = parsed {
                        row_count = count;
                    }
                }
                Message::ReadyForQuery(_) => return Ok(row_count),
                Message::ErrorResponse(body) => return Err(self.consume_error(&body)),
                _ => {}
            }
        }
    }

    /// Aborts COPY IN with CopyFail(`reason`) and waits for `ReadyForQuery`.
    ///
    /// The `ErrorResponse` the server sends in reply is expected and ignored.
    pub fn cancel_copy(&mut self, reason: &str) -> Result<()> {
        frontend::copy_fail(reason, &mut self.write_buf);
        self.flush()?;
        loop {
            if matches!(self.read_response_message()?, Message::ReadyForQuery(_)) {
                return Ok(());
            }
        }
    }

    /// Runs a COPY TO STDOUT `query` and returns all CopyData payloads
    /// concatenated.
    pub fn copy_out(&mut self, query: &str) -> Result<Vec<u8>> {
        let mut data = Vec::new();
        self.copy_out_to_writer(query, &mut data)?;
        Ok(data)
    }

    /// Runs a COPY TO STDOUT `query`, streams each CopyData payload into
    /// `writer`, and returns the number of bytes written.
    ///
    /// # Errors
    ///
    /// - If `writer` fails, the rest of the response is still consumed up to
    ///   `ReadyForQuery` (and discarded) so the connection stays usable. The
    ///   write error is then returned.
    /// - If the query unexpectedly starts COPY IN, the copy is cancelled and a
    ///   protocol error is returned instead of deadlocking.
    /// - The server's error on `ErrorResponse`.
    pub fn copy_out_to_writer(
        &mut self,
        query: &str,
        writer: &mut dyn std::io::Write,
    ) -> Result<u64> {
        self.ensure_healthy()?;
        self.send_query(query)?;
        let mut total_bytes: u64 = 0;
        let mut in_copy_out = false;
        let mut write_error: Option<Error> = None;
        loop {
            match self.read_response_message()? {
                Message::CopyOutResponse(_) => in_copy_out = true,
                Message::CopyData(body) if in_copy_out => {
                    if write_error.is_none() {
                        let chunk = body.data();
                        match writer.write_all(chunk) {
                            Ok(()) => total_bytes += chunk.len() as u64,
                            Err(e) => {
                                write_error = Some(Error::new(
                                    super::error::ErrorKind::Io,
                                    format!("Failed to write COPY data: {e}"),
                                ));
                            }
                        }
                    }
                }
                Message::CopyDone => in_copy_out = false,
                Message::CopyInResponse(_) => {
                    // The server now waits for our data; fail the copy so it
                    // returns to ReadyForQuery instead of both sides blocking.
                    self.cancel_copy("expected COPY TO STDOUT but server started COPY FROM STDIN")?;
                    return Err(Error::protocol(
                        "query started COPY IN mode where COPY OUT was expected",
                    ));
                }
                Message::CommandComplete(_) => {}
                Message::ReadyForQuery(_) => return write_error.map_or(Ok(total_bytes), Err),
                Message::ErrorResponse(body) => {
                    let server_error = self.consume_error(&body);
                    return Err(write_error.unwrap_or(server_error));
                }
                _ => {}
            }
        }
    }

    /// Encodes `query` as a simple-query message (rolled back on failure) and
    /// flushes it.
    fn send_query(&mut self, query: &str) -> Result<()> {
        self.encode_with_rollback(|buf| {
            frontend::query(query, buf)?;
            Ok(())
        })?;
        self.flush()
    }

    /// Encodes Parse/Bind/Describe(portal)/Execute/Sync for `query`, requesting
    /// Hyper binary results on the unnamed statement and portal.
    fn encode_binary_query(&mut self, query: &str) -> Result<()> {
        self.encode_with_rollback(|buf| {
            frontend::parse("", query, &[], buf)?;
            frontend::bind("", "", &[], &[], &[Self::HYPER_BINARY_FORMAT], buf)?;
            frontend::describe(b'P', "", buf)?;
            frontend::execute("", 0, buf)?;
            frontend::sync(buf);
            Ok(())
        })
    }

    /// Runs `encode` against the write buffer.
    ///
    /// If it fails, the buffer is truncated back to its previous length, so a
    /// half-encoded message is never sent by a later flush.
    fn encode_with_rollback<F>(&mut self, encode: F) -> Result<()>
    where
        F: FnOnce(&mut BytesMut) -> Result<()>,
    {
        let mark = self.write_buf.len();
        let result = encode(&mut self.write_buf);
        if result.is_err() {
            self.write_buf.truncate(mark);
        }
        result
    }

    /// Reads a message that belongs to an in-flight response.
    ///
    /// If the read fails, the rest of the response is left unread, so the
    /// connection is marked desynchronized.
    fn read_response_message(&mut self) -> Result<Message> {
        let result = self.read_message();
        if result.is_err() {
            self.desynchronized = true;
        }
        result
    }

    /// Collects messages up to and including `ReadyForQuery`.
    ///
    /// On `ErrorResponse`, drains the remainder and returns the server error.
    fn collect_until_ready(&mut self) -> Result<Vec<Message>> {
        let mut messages = Vec::new();
        loop {
            let msg = self.read_response_message()?;
            if let Message::ErrorResponse(body) = &msg {
                return Err(self.consume_error(body));
            }
            let is_ready = matches!(msg, Message::ReadyForQuery(_));
            messages.push(msg);
            if is_ready {
                return Ok(messages);
            }
        }
    }

    /// Writes `pending` (then clears it), followed by `header` and `data`, to
    /// `stream`.
    fn write_copy_frame(
        stream: &mut S,
        pending: &mut BytesMut,
        header: &[u8],
        data: &[u8],
    ) -> std::io::Result<()> {
        if !pending.is_empty() {
            stream.write_all(&pending[..])?;
            pending.clear();
        }
        stream.write_all(header)?;
        stream.write_all(data)
    }
}
pub(crate) fn parse_error_response(
body: &crate::protocol::message::backend::ErrorResponseBody,
) -> Error {
let mut severity = String::from("ERROR");
let mut code = String::from("00000");
let mut message = String::from("unknown error");
for field in body.fields().filter_map(|r| {
r.map_err(|e| trace!(target: "hyperdb_api_core::client", error = %e, "dropped error parsing error response field")).ok()
}) {
match field.type_() {
b'S' | b'V' => {
if let Ok(s) = field.value() {
severity = s.to_string();
}
}
b'C' => {
if let Ok(s) = field.value() {
code = s.to_string();
}
}
b'M' => {
if let Ok(s) = field.value() {
message = s.to_string();
}
}
_ => {}
}
}
Error::db(&severity, &code, &message)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Stream that is always at EOF for reads and accepts every write.
    struct NullStream;

    impl std::io::Read for NullStream {
        fn read(&mut self, _: &mut [u8]) -> std::io::Result<usize> {
            Ok(0)
        }
    }

    impl std::io::Write for NullStream {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            Ok(buf.len())
        }
        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn fresh_connection_is_healthy() {
        let conn = RawConnection::new(NullStream);
        assert!(conn.is_healthy());
        assert!(conn.ensure_healthy().is_ok());
    }

    #[test]
    fn fresh_connection_has_no_backend_state() {
        let conn = RawConnection::new(NullStream);
        assert_eq!(conn.process_id(), 0);
        assert_eq!(conn.secret_key(), 0);
        assert_eq!(conn.parameter_status("server_version"), None);
    }

    #[test]
    fn desynchronized_connection_fails_health_check() {
        let mut conn = RawConnection::new(NullStream);
        conn.desynchronized = true;
        assert!(!conn.is_healthy());
        let err = conn.ensure_healthy().expect_err("must fail-fast");
        assert_eq!(err.kind(), crate::client::error::ErrorKind::Connection);
        assert!(
            err.to_string().to_lowercase().contains("desynchron"),
            "error message should mention desynchronization; got: {err}",
        );
    }

    #[test]
    fn zero_budget_drain_marks_desynchronized() {
        let mut conn = RawConnection::new(Cursor::new(Vec::<u8>::new()));
        assert!(conn.is_healthy());
        let ok = conn.drain_until_ready_bounded(0);
        assert!(!ok, "zero-budget drain must return false");
        assert!(
            !conn.is_healthy(),
            "drain failure must mark connection desynchronized",
        );
    }

    /// Covers the read-error branch: EOF before ReadyForQuery.
    #[test]
    fn drain_on_closed_stream_marks_desynchronized() {
        let mut conn = RawConnection::new(NullStream);
        let ok = conn.drain_until_ready_bounded(POST_ERROR_DRAIN_CAP);
        assert!(!ok, "drain on a closed stream must return false");
        assert!(!conn.is_healthy());
        assert!(conn.ensure_healthy().is_err());
    }

    /// The unbounded drain must still stop when the peer has gone away.
    #[test]
    fn unbounded_drain_terminates_on_closed_stream() {
        let mut conn = RawConnection::new(NullStream);
        conn.drain_until_ready();
        assert!(!conn.is_healthy());
    }

    #[test]
    fn desynchronized_connection_fast_fails_simple_query() {
        let mut conn = RawConnection::new(Cursor::new(Vec::<u8>::new()));
        conn.desynchronized = true;
        let Err(err) = conn.simple_query("SELECT 1") else {
            panic!("desynced simple_query must fail-fast")
        };
        assert_eq!(err.kind(), crate::client::error::ErrorKind::Connection);
        assert!(err.to_string().to_lowercase().contains("desynchron"));
    }
}