use std::sync::Arc;
use hyperdb_api_core::client::OwnedPreparedStatement;
use hyperdb_api_core::types::Oid;
use crate::connection::Connection;
use crate::error::{Error, Result};
use crate::params::ToSqlParam;
use crate::result::{ResultColumn, ResultSchema, Row, RowValue, Rowset};
use crate::transport::Transport;
/// A prepared statement that executes on a borrowed [`Connection`].
///
/// Pairs the underlying `OwnedPreparedStatement` with a [`ResultSchema`]
/// derived from the statement's column metadata at construction time.
/// Execution goes through `tcp_client`, so it only works on a TCP transport.
#[derive(Debug)]
pub struct PreparedStatement<'conn> {
    // Connection the statement is executed on; rowsets returned by `query`
    // carry the same `'conn` lifetime.
    connection: &'conn Connection,
    // Low-level statement handle: parameter metadata, query text, columns.
    inner: OwnedPreparedStatement,
    // Result schema built once from `inner.columns()` in `new`.
    schema: Arc<ResultSchema>,
}
impl<'conn> PreparedStatement<'conn> {
    /// Wraps a prepared statement handle and derives its result schema.
    ///
    /// The schema is computed once from `inner.columns()` and stored, so
    /// [`schema`](Self::schema) does no further work.
    ///
    /// # Errors
    ///
    /// Currently never fails. The `Result` is kept for API symmetry and so
    /// callers do not break if construction later gains failure cases.
    #[expect(
        clippy::unnecessary_wraps,
        reason = "signature retained for API symmetry / future fallibility; returning Result/Option keeps callers from breaking when the function later grows failure cases"
    )]
    pub(crate) fn new(
        connection: &'conn Connection,
        inner: OwnedPreparedStatement,
    ) -> Result<Self> {
        let schema = build_schema_from_columns(inner.columns());
        Ok(Self {
            connection,
            inner,
            schema: Arc::new(schema),
        })
    }

    /// Number of parameters declared by the prepared statement.
    #[must_use]
    pub fn param_count(&self) -> usize {
        self.inner.param_count()
    }

    /// Type OIDs of the statement's declared parameters.
    #[must_use]
    pub fn param_types(&self) -> &[Oid] {
        self.inner.param_types()
    }

    /// Result schema derived from the statement's column metadata.
    #[must_use]
    pub fn schema(&self) -> &ResultSchema {
        &self.schema
    }

    /// Query text held by the underlying statement handle.
    #[must_use]
    pub fn sql(&self) -> &str {
        self.inner.query()
    }

    /// Executes the statement with `params` and returns a streaming [`Rowset`].
    ///
    /// Parameters are encoded positionally with `ToSqlParam::encode_param`.
    /// Their count is not checked against [`param_count`](Self::param_count)
    /// here; a mismatch is left to the server to reject.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection does not use the TCP transport
    /// (prepared statements are unsupported over gRPC), or if the streaming
    /// execution itself fails.
    pub fn query(&self, params: &[&dyn ToSqlParam]) -> Result<Rowset<'conn>> {
        // Resolve the transport before encoding, so an unsupported connection
        // fails fast without allocating and encoding every parameter.
        let client = tcp_client(self.connection)?;
        let stream = client.execute_streaming(
            &self.inner,
            encode_params(params),
            crate::result::DEFAULT_BINARY_CHUNK_SIZE,
        )?;
        Ok(Rowset::from_prepared(stream))
    }

    /// Executes the statement with `params` without reading a result set.
    ///
    /// Returns the count reported by `execute_no_result`. This is presumably
    /// the number of affected rows; confirm against the client API.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection does not use the TCP transport, or
    /// if execution fails.
    pub fn execute(&self, params: &[&dyn ToSqlParam]) -> Result<u64> {
        // Same fail-fast ordering as `query`: check the transport first.
        let client = tcp_client(self.connection)?;
        Ok(client.execute_no_result(&self.inner, encode_params(params))?)
    }

    /// Runs the statement and returns the first row via `Rowset::require_first_row`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`query`](Self::query) and from
    /// `require_first_row` (presumably raised when no row is returned).
    pub fn fetch_one(&self, params: &[&dyn ToSqlParam]) -> Result<Row> {
        self.query(params)?.require_first_row()
    }

    /// Runs the statement and returns the first row, if any.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`query`](Self::query) and `Rowset::first_row`.
    pub fn fetch_optional(&self, params: &[&dyn ToSqlParam]) -> Result<Option<Row>> {
        self.query(params)?.first_row()
    }

    /// Runs the statement and collects every returned row.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`query`](Self::query) and `Rowset::collect_rows`.
    pub fn fetch_all(&self, params: &[&dyn ToSqlParam]) -> Result<Vec<Row>> {
        self.query(params)?.collect_rows()
    }

    /// Runs the statement and decodes a single scalar value of type `T`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`query`](Self::query) and
    /// `Rowset::require_scalar` (presumably raised when no value is present).
    pub fn fetch_scalar<T: RowValue>(&self, params: &[&dyn ToSqlParam]) -> Result<T> {
        self.query(params)?.require_scalar()
    }

    /// Runs the statement and decodes an optional scalar value of type `T`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`query`](Self::query) and `Rowset::scalar`.
    pub fn fetch_optional_scalar<T: RowValue>(
        &self,
        params: &[&dyn ToSqlParam],
    ) -> Result<Option<T>> {
        self.query(params)?.scalar()
    }
}
/// Encodes each parameter positionally into its wire representation.
///
/// The output has one entry per input parameter, in the same order. Each
/// entry is whatever `ToSqlParam::encode_param` produced for that parameter.
pub(crate) fn encode_params(params: &[&dyn ToSqlParam]) -> Vec<Option<Vec<u8>>> {
    let mut encoded = Vec::with_capacity(params.len());
    for param in params {
        encoded.push(param.encode_param());
    }
    encoded
}
pub(crate) fn tcp_client(connection: &Connection) -> Result<&hyperdb_api_core::client::Client> {
match connection.transport() {
Transport::Tcp(tcp) => Ok(&tcp.client),
Transport::Grpc(_) => Err(Error::new(
"prepared statements are not supported over gRPC transport",
)),
}
}
/// Builds a [`ResultSchema`] from a prepared statement's column descriptions.
///
/// Each column keeps its name and zero-based position. Its SQL type is
/// resolved from the column's type OID and type modifier.
fn build_schema_from_columns(cols: &[hyperdb_api_core::client::Column]) -> ResultSchema {
    use hyperdb_api_core::types::SqlType;

    let result_columns = cols
        .iter()
        .enumerate()
        .map(|(position, column)| {
            let resolved_type =
                SqlType::from_oid_and_modifier(column.type_oid().0, column.type_modifier());
            ResultColumn::new(column.name(), resolved_type, position)
        })
        .collect();

    ResultSchema::from_columns(result_columns)
}