use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, Weak};
use crate::protocol::message::{backend::Message, frontend};
use crate::types::Oid;
use tracing::{trace, warn};
use super::connection::RawConnection;
use super::error::{Error, Result};
use super::row::Row;
use super::statement::Column;
use super::sync_stream::SyncStream;
pub trait SqlParam {
fn encode(&self) -> Vec<u8>;
}
impl SqlParam for i16 {
#[inline]
fn encode(&self) -> Vec<u8> {
self.to_le_bytes().to_vec()
}
}
impl SqlParam for i32 {
#[inline]
fn encode(&self) -> Vec<u8> {
self.to_le_bytes().to_vec()
}
}
impl SqlParam for i64 {
#[inline]
fn encode(&self) -> Vec<u8> {
self.to_le_bytes().to_vec()
}
}
impl SqlParam for f32 {
#[inline]
fn encode(&self) -> Vec<u8> {
self.to_le_bytes().to_vec()
}
}
impl SqlParam for f64 {
#[inline]
fn encode(&self) -> Vec<u8> {
self.to_le_bytes().to_vec()
}
}
impl SqlParam for bool {
#[inline]
fn encode(&self) -> Vec<u8> {
vec![u8::from(*self)]
}
}
impl SqlParam for &str {
#[inline]
fn encode(&self) -> Vec<u8> {
self.as_bytes().to_vec()
}
}
impl SqlParam for String {
#[inline]
fn encode(&self) -> Vec<u8> {
self.as_bytes().to_vec()
}
}
impl SqlParam for &String {
#[inline]
fn encode(&self) -> Vec<u8> {
self.as_bytes().to_vec()
}
}
impl SqlParam for Vec<u8> {
#[inline]
fn encode(&self) -> Vec<u8> {
self.clone()
}
}
impl SqlParam for &[u8] {
#[inline]
fn encode(&self) -> Vec<u8> {
self.to_vec()
}
}
#[macro_export]
macro_rules! params {
() => {
&[] as &[Option<Vec<u8>>]
};
($($val:expr),+ $(,)?) => {{
use $crate::client::prepare::SqlParam;
vec![$(Some($val.encode())),+]
}};
}
static STATEMENT_COUNTER: AtomicU64 = AtomicU64::new(0);
fn generate_statement_name() -> String {
let id = STATEMENT_COUNTER.fetch_add(1, Ordering::Relaxed);
format!("__hyper_stmt_{id}")
}
#[derive(Debug)]
pub struct PreparedStatement {
name: String,
query: String,
param_types: Vec<Oid>,
columns: Vec<Column>,
}
#[derive(Debug)]
pub struct OwnedPreparedStatement {
statement: PreparedStatement,
connection: Weak<Mutex<RawConnection<SyncStream>>>,
}
impl OwnedPreparedStatement {
pub(crate) fn new(
statement: PreparedStatement,
connection: &Arc<Mutex<RawConnection<SyncStream>>>,
) -> Self {
OwnedPreparedStatement {
statement,
connection: Arc::downgrade(connection),
}
}
#[must_use]
pub fn name(&self) -> &str {
self.statement.name()
}
#[must_use]
pub fn query(&self) -> &str {
self.statement.query()
}
#[must_use]
pub fn param_types(&self) -> &[Oid] {
self.statement.param_types()
}
#[must_use]
pub fn param_count(&self) -> usize {
self.statement.param_count()
}
#[must_use]
pub fn columns(&self) -> &[Column] {
self.statement.columns()
}
#[must_use]
pub fn column_count(&self) -> usize {
self.statement.column_count()
}
#[must_use]
pub fn statement(&self) -> &PreparedStatement {
&self.statement
}
pub fn close(self) -> Result<()> {
if let Some(conn) = self.connection.upgrade() {
close_statement(&conn, &self.statement)?;
}
std::mem::forget(self);
Ok(())
}
}
impl Drop for OwnedPreparedStatement {
fn drop(&mut self) {
if let Some(conn) = self.connection.upgrade() {
if let Err(e) = close_statement_internal(&conn, &self.statement) {
warn!(
target: "hyperdb_api",
statement_name = %self.statement.name,
error = %e,
"failed-to-close-prepared-statement-during-drop"
);
}
}
}
}
impl PreparedStatement {
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
#[must_use]
pub fn query(&self) -> &str {
&self.query
}
#[must_use]
pub fn param_types(&self) -> &[Oid] {
&self.param_types
}
#[must_use]
pub fn param_count(&self) -> usize {
self.param_types.len()
}
#[must_use]
pub fn columns(&self) -> &[Column] {
&self.columns
}
#[must_use]
pub fn column_count(&self) -> usize {
self.columns.len()
}
}
pub fn prepare(
connection: &Arc<Mutex<RawConnection<SyncStream>>>,
query: &str,
param_types: &[Oid],
) -> Result<PreparedStatement> {
let name = generate_statement_name();
let mut conn = connection
.lock()
.map_err(|_| Error::connection("connection mutex poisoned"))?;
frontend::parse(&name, query, param_types, conn.write_buf())?;
frontend::describe(b'S', &name, conn.write_buf())?;
frontend::sync(conn.write_buf());
conn.flush()?;
let mut parsed_params = Vec::new();
let mut parsed_columns = Vec::new();
loop {
let msg = conn.read_message()?;
match msg {
Message::ParseComplete => {
}
Message::ParameterDescription(desc) => {
for oid in desc.parameters().filter_map(|r| {
r.map_err(|e| trace!(target: "hyperdb_api_core::client", error = %e, "dropped error parsing parameter OID")).ok()
}) {
parsed_params.push(oid);
}
}
Message::RowDescription(desc) => {
for f in desc.fields().filter_map(|r| {
r.map_err(|e| trace!(target: "hyperdb_api_core::client", error = %e, "dropped error parsing row description field")).ok()
}) {
parsed_columns.push(Column::new(
f.name().to_string(),
f.type_oid(),
f.type_modifier(),
super::statement::ColumnFormat::from_code(f.format()),
));
}
}
Message::NoData => {
}
Message::ReadyForQuery(_) => {
break;
}
Message::ErrorResponse(body) => {
return Err(conn.consume_error(&body));
}
_ => {}
}
}
Ok(PreparedStatement {
name,
query: query.to_string(),
param_types: parsed_params,
columns: parsed_columns,
})
}
pub fn execute_prepared(
connection: &Arc<Mutex<RawConnection<SyncStream>>>,
statement: &PreparedStatement,
params: &[Option<&[u8]>],
) -> Result<Vec<Row>> {
let mut conn = connection
.lock()
.map_err(|_| Error::connection("connection mutex poisoned"))?;
let param_formats: Vec<i16> = vec![1; params.len()]; let result_formats: Vec<i16> = vec![1; statement.columns.len()];
frontend::bind(
"", &statement.name,
¶m_formats,
params,
&result_formats,
conn.write_buf(),
)?;
frontend::execute("", 0, conn.write_buf())?;
frontend::sync(conn.write_buf());
conn.flush()?;
let mut rows = Vec::new();
let columns = Arc::new(statement.columns.clone());
loop {
let msg = conn.read_message()?;
match msg {
Message::BindComplete => {
}
Message::DataRow(data) => {
rows.push(Row::new(Arc::clone(&columns), data)?);
}
Message::CommandComplete(_) => {
}
Message::EmptyQueryResponse => {
}
Message::ReadyForQuery(_) => {
break;
}
Message::ErrorResponse(body) => {
return Err(conn.consume_error(&body));
}
_ => {}
}
}
Ok(rows)
}
pub fn execute_prepared_no_result(
connection: &Arc<Mutex<RawConnection<SyncStream>>>,
statement: &PreparedStatement,
params: &[Option<&[u8]>],
) -> Result<u64> {
let mut conn = connection
.lock()
.map_err(|_| Error::connection("connection mutex poisoned"))?;
let param_formats: Vec<i16> = vec![1; params.len()];
let result_formats: Vec<i16> = vec![];
frontend::bind(
"",
&statement.name,
¶m_formats,
params,
&result_formats,
conn.write_buf(),
)?;
frontend::execute("", 0, conn.write_buf())?;
frontend::sync(conn.write_buf());
conn.flush()?;
let mut affected_rows = 0u64;
loop {
let msg = conn.read_message()?;
match msg {
Message::BindComplete => {}
Message::CommandComplete(body) => {
if let Ok(tag) = body.tag() {
affected_rows = parse_affected_rows(tag);
}
}
Message::EmptyQueryResponse => {}
Message::ReadyForQuery(_) => {
break;
}
Message::ErrorResponse(body) => {
return Err(conn.consume_error(&body));
}
_ => {}
}
}
Ok(affected_rows)
}
pub fn close_statement(
connection: &Arc<Mutex<RawConnection<SyncStream>>>,
statement: &PreparedStatement,
) -> Result<()> {
close_statement_internal(connection, statement)
}
fn close_statement_internal(
connection: &Arc<Mutex<RawConnection<SyncStream>>>,
statement: &PreparedStatement,
) -> Result<()> {
let mut conn = connection
.lock()
.map_err(|_| Error::connection("connection mutex poisoned"))?;
frontend::close(b'S', &statement.name, conn.write_buf())?;
frontend::sync(conn.write_buf());
conn.flush()?;
loop {
let msg = conn.read_message()?;
match msg {
Message::CloseComplete => {}
Message::ReadyForQuery(_) => {
break;
}
Message::ErrorResponse(body) => {
return Err(conn.consume_error(&body));
}
_ => {}
}
}
Ok(())
}
pub fn prepare_owned(
connection: &Arc<Mutex<RawConnection<SyncStream>>>,
query: &str,
param_types: &[Oid],
) -> Result<OwnedPreparedStatement> {
let statement = prepare(connection, query, param_types)?;
Ok(OwnedPreparedStatement::new(statement, connection))
}
fn parse_affected_rows(tag: &str) -> u64 {
let parts: Vec<&str> = tag.split_whitespace().collect();
match parts.first() {
Some(&"INSERT") => parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0),
Some(&"UPDATE" | &"DELETE" | &"SELECT" | &"COPY") => {
parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0)
}
_ => 0,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sql_param_i16() {
assert_eq!(0_i16.encode(), vec![0, 0]);
assert_eq!(1_i16.encode(), vec![1, 0]);
assert_eq!((-1_i16).encode(), vec![255, 255]);
}
#[test]
fn test_sql_param_i32() {
assert_eq!(0_i32.encode(), vec![0, 0, 0, 0]);
assert_eq!(1_i32.encode(), vec![1, 0, 0, 0]);
assert_eq!((-1_i32).encode(), vec![255, 255, 255, 255]);
assert_eq!(256_i32.encode(), vec![0, 1, 0, 0]);
}
#[test]
fn test_sql_param_i64() {
assert_eq!(0_i64.encode(), vec![0, 0, 0, 0, 0, 0, 0, 0]);
assert_eq!(1_i64.encode(), vec![1, 0, 0, 0, 0, 0, 0, 0]);
assert_eq!(
(-1_i64).encode(),
vec![255, 255, 255, 255, 255, 255, 255, 255]
);
}
#[test]
#[expect(
clippy::float_cmp,
reason = "1.5 is exactly representable; encode/decode must round-trip bit-for-bit"
)]
fn test_sql_param_f32() {
let encoded = 1.5_f32.encode();
assert_eq!(encoded.len(), 4);
let decoded = f32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]);
assert_eq!(decoded, 1.5);
}
#[test]
#[expect(
clippy::float_cmp,
reason = "1.5 is exactly representable; encode/decode must round-trip bit-for-bit"
)]
fn test_sql_param_f64() {
let encoded = 1.5_f64.encode();
assert_eq!(encoded.len(), 8);
let decoded = f64::from_le_bytes([
encoded[0], encoded[1], encoded[2], encoded[3], encoded[4], encoded[5], encoded[6],
encoded[7],
]);
assert_eq!(decoded, 1.5);
}
#[test]
fn test_sql_param_bool() {
assert_eq!(true.encode(), vec![1]);
assert_eq!(false.encode(), vec![0]);
}
#[test]
fn test_sql_param_str() {
assert_eq!("hello".encode(), b"hello".to_vec());
assert_eq!("".encode(), Vec::<u8>::new());
assert_eq!("héllo".encode(), "héllo".as_bytes().to_vec());
}
#[test]
fn test_sql_param_string() {
let s = String::from("hello");
assert_eq!(s.encode(), b"hello".to_vec());
assert_eq!(s.encode(), b"hello".to_vec());
}
#[test]
fn test_sql_param_bytes() {
let bytes: Vec<u8> = vec![1, 2, 3, 4];
assert_eq!(bytes.encode(), vec![1, 2, 3, 4]);
assert_eq!(bytes.as_slice().encode(), vec![1, 2, 3, 4]);
}
#[test]
fn test_params_macro_empty() {
let p = params![];
assert!(p.is_empty());
}
#[test]
fn test_params_macro_single() {
let p = params![42_i32];
assert_eq!(p.len(), 1);
assert_eq!(p[0], Some(vec![42, 0, 0, 0]));
}
#[test]
fn test_params_macro_multiple() {
let p = params![42_i32, "hello", true];
assert_eq!(p.len(), 3);
assert_eq!(p[0], Some(vec![42, 0, 0, 0]));
assert_eq!(p[1], Some(b"hello".to_vec()));
assert_eq!(p[2], Some(vec![1]));
}
}