use std::sync::Arc;
use hyperdb_api_core::client::AsyncPreparedStatement as LowLevelAsyncPreparedStatement;
use hyperdb_api_core::types::Oid;
use crate::async_connection::AsyncConnection;
use crate::async_result::AsyncRowset;
use crate::async_transport::AsyncTransport;
use crate::error::{Error, Result};
use crate::params::ToSqlParam;
use crate::result::{ResultColumn, ResultSchema, Row, RowValue};
/// A prepared statement that borrows the [`AsyncConnection`] it executes on.
///
/// The `'conn` lifetime ties the statement to that connection borrow, so the
/// statement cannot outlive the connection reference it was created with.
#[derive(Debug)]
pub struct AsyncPreparedStatement<'conn> {
/// Connection the statement is executed on.
connection: &'conn AsyncConnection,
/// Low-level prepared statement from `hyperdb_api_core` that this type wraps.
inner: LowLevelAsyncPreparedStatement,
/// Result schema, built from `inner`'s columns when the statement is constructed.
schema: Arc<ResultSchema>,
}
impl<'conn> AsyncPreparedStatement<'conn> {
    /// Wraps a low-level prepared statement together with the connection it runs on.
    ///
    /// The result schema is derived once, here, from the columns reported by
    /// `inner`, and stored behind an [`Arc`].
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` return type is kept deliberately
    /// (see the `expect` reason below).
    #[expect(
        clippy::unnecessary_wraps,
        reason = "signature retained for API symmetry / future fallibility; returning Result/Option keeps callers from breaking when the function later grows failure cases"
    )]
    pub(crate) fn new(
        connection: &'conn AsyncConnection,
        inner: LowLevelAsyncPreparedStatement,
    ) -> Result<Self> {
        // Read the column metadata before `inner` is moved into the struct.
        let schema = Arc::new(build_schema_from_columns(inner.columns()));
        Ok(Self {
            connection,
            inner,
            schema,
        })
    }

    /// Returns the number of parameters reported by the underlying statement.
    #[must_use]
    pub fn param_count(&self) -> usize {
        self.inner.param_count()
    }

    /// Returns the parameter type OIDs reported by the underlying statement.
    #[must_use]
    pub fn param_types(&self) -> &[Oid] {
        self.inner.param_types()
    }

    /// Returns the result schema computed when this statement was constructed.
    #[must_use]
    pub fn schema(&self) -> &ResultSchema {
        &self.schema
    }

    /// Returns the query text held by the underlying statement.
    #[must_use]
    pub fn sql(&self) -> &str {
        self.inner.query()
    }

    /// Executes the statement with `params` and returns a streaming rowset.
    ///
    /// Parameters are encoded first, then the statement is run through the
    /// connection's TCP client using `DEFAULT_BINARY_CHUNK_SIZE` as the chunk size.
    ///
    /// # Errors
    ///
    /// Fails if the connection is not on the TCP transport, or if the
    /// underlying streaming execution reports an error.
    pub async fn query(&self, params: &[&dyn ToSqlParam]) -> Result<AsyncRowset<'conn>> {
        let bound_values = encode_params(params);
        let tcp = async_tcp_client(self.connection)?;
        let chunk_size = crate::result::DEFAULT_BINARY_CHUNK_SIZE;
        let row_stream = tcp
            .execute_prepared_streaming(&self.inner, bound_values, chunk_size)
            .await?;
        Ok(AsyncRowset::from_prepared(row_stream))
    }

    /// Executes the statement with `params` without reading a result set and
    /// returns the `u64` reported by the client.
    ///
    /// # Errors
    ///
    /// Fails if the connection is not on the TCP transport, or if the
    /// underlying execution reports an error.
    pub async fn execute(&self, params: &[&dyn ToSqlParam]) -> Result<u64> {
        let bound_values = encode_params(params);
        let tcp = async_tcp_client(self.connection)?;
        let reported = tcp
            .execute_prepared_no_result(&self.inner, bound_values)
            .await?;
        Ok(reported)
    }

    /// Runs [`Self::query`] and delegates to the rowset's `require_first_row`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::query`] or from `require_first_row`.
    pub async fn fetch_one(&self, params: &[&dyn ToSqlParam]) -> Result<Row> {
        self.query(params).await?.require_first_row().await
    }

    /// Runs [`Self::query`] and delegates to the rowset's `first_row`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::query`] or from `first_row`.
    pub async fn fetch_optional(&self, params: &[&dyn ToSqlParam]) -> Result<Option<Row>> {
        self.query(params).await?.first_row().await
    }

    /// Runs [`Self::query`] and collects every row of the result into a `Vec`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::query`] or from collecting the rows.
    pub async fn fetch_all(&self, params: &[&dyn ToSqlParam]) -> Result<Vec<Row>> {
        let rowset = self.query(params).await?;
        rowset.collect_rows().await
    }

    /// Runs [`Self::query`] and delegates to the rowset's `require_scalar`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::query`] or from `require_scalar`.
    pub async fn fetch_scalar<T: RowValue>(&self, params: &[&dyn ToSqlParam]) -> Result<T> {
        self.query(params).await?.require_scalar().await
    }

    /// Runs [`Self::query`] and delegates to the rowset's `scalar`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::query`] or from `scalar`.
    pub async fn fetch_optional_scalar<T: RowValue>(
        &self,
        params: &[&dyn ToSqlParam],
    ) -> Result<Option<T>> {
        self.query(params).await?.scalar().await
    }
}
/// Encodes every parameter, in order, by calling `encode_param` on it.
///
/// The returned vector has exactly one entry per element of `params`; each
/// entry is whatever that parameter's `encode_param` returned.
pub(crate) fn encode_params(params: &[&dyn ToSqlParam]) -> Vec<Option<Vec<u8>>> {
    let mut encoded = Vec::with_capacity(params.len());
    for param in params {
        encoded.push(param.encode_param());
    }
    encoded
}
/// A prepared statement that holds shared ownership of its [`AsyncConnection`].
///
/// Unlike a statement that borrows the connection, this type carries no
/// lifetime parameter because the connection is kept alive through an [`Arc`].
#[derive(Debug)]
pub struct AsyncPreparedStatementOwned {
/// Shared handle to the connection the statement is executed on.
connection: Arc<AsyncConnection>,
/// Low-level prepared statement from `hyperdb_api_core` that this type wraps.
inner: LowLevelAsyncPreparedStatement,
/// Result schema, built from `inner`'s columns when the statement is constructed.
schema: Arc<ResultSchema>,
}
impl AsyncPreparedStatementOwned {
#[expect(
clippy::unnecessary_wraps,
reason = "signature retained for API symmetry / future fallibility; returning Result/Option keeps callers from breaking when the function later grows failure cases"
)]
pub(crate) fn new(
connection: Arc<AsyncConnection>,
inner: LowLevelAsyncPreparedStatement,
) -> Result<Self> {
let schema = build_schema_from_columns(inner.columns());
Ok(Self {
connection,
inner,
schema: Arc::new(schema),
})
}
#[must_use]
pub fn param_count(&self) -> usize {
self.inner.param_count()
}
#[must_use]
pub fn param_types(&self) -> &[Oid] {
self.inner.param_types()
}
#[must_use]
pub fn schema(&self) -> &ResultSchema {
&self.schema
}
#[must_use]
pub fn sql(&self) -> &str {
self.inner.query()
}
pub async fn fetch_all(&self, params: &[&dyn ToSqlParam]) -> Result<Vec<Row>> {
let encoded = encode_params(params);
let client = async_tcp_client_arc(&self.connection)?;
let stream = client
.execute_prepared_streaming(
&self.inner,
encoded,
crate::result::DEFAULT_BINARY_CHUNK_SIZE,
)
.await?;
let rowset = AsyncRowset::from_prepared(stream);
rowset.collect_rows().await
}
pub async fn execute(&self, params: &[&dyn ToSqlParam]) -> Result<u64> {
let encoded = encode_params(params);
let client = async_tcp_client_arc(&self.connection)?;
Ok(client
.execute_prepared_no_result(&self.inner, encoded)
.await?)
}
pub async fn fetch_one(&self, params: &[&dyn ToSqlParam]) -> Result<Row> {
self.fetch_all(params)
.await?
.into_iter()
.next()
.ok_or_else(|| crate::error::Error::new("Query returned no rows"))
}
pub async fn fetch_optional(&self, params: &[&dyn ToSqlParam]) -> Result<Option<Row>> {
Ok(self.fetch_all(params).await?.into_iter().next())
}
pub async fn fetch_scalar<T: RowValue>(&self, params: &[&dyn ToSqlParam]) -> Result<T> {
let row = self.fetch_one(params).await?;
row.get::<T>(0)
.ok_or_else(|| crate::error::Error::new("Scalar query returned NULL"))
}
pub async fn fetch_optional_scalar<T: RowValue>(
&self,
params: &[&dyn ToSqlParam],
) -> Result<Option<T>> {
Ok(self
.fetch_optional(params)
.await?
.and_then(|r| r.get::<T>(0)))
}
pub fn close(self) {
drop(self);
}
}
/// Returns the TCP client behind a shared connection handle.
///
/// # Errors
///
/// Fails when the connection uses the gRPC transport, which does not support
/// prepared statements.
fn async_tcp_client_arc(
    connection: &Arc<AsyncConnection>,
) -> Result<&hyperdb_api_core::client::AsyncClient> {
    let transport = connection.transport();
    // Both variants are listed explicitly so a new transport forces a decision here.
    match transport {
        AsyncTransport::Grpc(_) => Err(Error::new(
            "prepared statements are not supported over gRPC transport",
        )),
        AsyncTransport::Tcp(tcp) => Ok(&tcp.client),
    }
}
/// Returns the TCP client of `connection`.
///
/// # Errors
///
/// Fails when the connection uses the gRPC transport, which does not support
/// prepared statements.
pub(crate) fn async_tcp_client(
    connection: &AsyncConnection,
) -> Result<&hyperdb_api_core::client::AsyncClient> {
    let transport = connection.transport();
    // Both variants are listed explicitly so a new transport forces a decision here.
    match transport {
        AsyncTransport::Grpc(_) => Err(Error::new(
            "prepared statements are not supported over gRPC transport",
        )),
        AsyncTransport::Tcp(tcp) => Ok(&tcp.client),
    }
}
/// Builds a [`ResultSchema`] from the column metadata of a prepared statement.
///
/// Each input column becomes one [`ResultColumn`] carrying the column's name,
/// the SQL type resolved from its type OID and type modifier, and its
/// zero-based position in `cols`.
fn build_schema_from_columns(cols: &[hyperdb_api_core::client::Column]) -> ResultSchema {
    let result_columns = cols.iter().enumerate().map(|(position, column)| {
        let oid = column.type_oid().0;
        let modifier = column.type_modifier();
        let sql_type = hyperdb_api_core::types::SqlType::from_oid_and_modifier(oid, modifier);
        ResultColumn::new(column.name(), sql_type, position)
    });
    ResultSchema::from_columns(result_columns.collect())
}