use core::fmt;
use core::marker::PhantomData;
use core::mem::ManuallyDrop;
use crate::{Bind, Error, Row, SendStatement};
/// A prepared statement tagged with its bind-parameter type `I` and its
/// output row type `O`.
///
/// `I::COUNT` is compared against the number of parameters passed to
/// [`TypedStatement::bind`] / `execute` at compile time. When converting from
/// a [`SendStatement`] via `TryFrom`, both `I::COUNT` and `O::COUNT` are
/// checked against the statement's parameter and column counts.
pub struct TypedStatement<I, O> {
    inner: SendStatement,
    // `fn() -> (I, O)` is covariant in `I` and `O` just like `(I, O)`, but it
    // does not claim ownership of an `I` or `O` value. None is ever stored
    // here, so `I`/`O` should not influence `Send`/`Sync` or drop checking.
    _marker: PhantomData<fn() -> (I, O)>,
}
impl<O> TypedStatement<(), O> {
    /// Prepares a parameterless statement for reading rows.
    ///
    /// This is shorthand for binding the empty parameter set `()`.
    ///
    /// # Errors
    ///
    /// Propagates any error produced while resetting or binding the statement.
    pub fn query(&mut self) -> Result<BoundStatement<'_, (), O>, Error> {
        Self::bind(self, ())
    }
}
impl<I> TypedStatement<I, ()>
where
    I: Bind,
{
    /// Resets the statement, binds `bind`, and steps it until it reports
    /// that it is done.
    ///
    /// Only available for statements whose output type is `()`. The number of
    /// parameters in `B` is checked against `I` when the call is compiled.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by resetting, binding, or stepping.
    #[track_caller]
    pub fn execute<B>(&mut self, bind: B) -> Result<(), Error>
    where
        B: Bind,
    {
        const {
            assert!(
                I::COUNT == B::COUNT,
                "unexpected bind parameter count for statement",
            );
        }

        let stmt = &mut self.inner;
        stmt.reset()?;
        stmt.bind(bind)?;

        loop {
            if stmt.step()?.is_done() {
                break Ok(());
            }
        }
    }
}
impl<I, O> TypedStatement<I, O>
where
    I: Bind,
{
    /// Resets the statement and binds `bind` to its parameters.
    ///
    /// The returned [`BoundStatement`] mutably borrows this statement, so rows
    /// can be read from it until it is consumed or dropped. The number of
    /// parameters in `B` is checked against `I` when the call is compiled.
    ///
    /// # Errors
    ///
    /// Returns an error if resetting or binding the statement fails.
    #[track_caller]
    pub fn bind<B>(&mut self, bind: B) -> Result<BoundStatement<'_, I, O>, Error>
    where
        B: Bind,
    {
        const {
            assert!(
                I::COUNT == B::COUNT,
                "unexpected bind parameter count for statement",
            );
        }

        let stmt = &mut self.inner;
        stmt.reset()?;
        stmt.bind(bind)?;

        Ok(BoundStatement {
            stmt,
            _marker: PhantomData,
        })
    }
}
/// The specific mismatch found while converting a [`SendStatement`] into a
/// [`TypedStatement`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TryFromSendStatementErrorKind {
    /// The statement's bind parameter count differs from the input type's count.
    BindParameterCount { expected: i32, actual: i32 },
    /// The statement's column count differs from the output row type's count.
    ColumnCount { expected: i32, actual: i32 },
}

/// Error returned when a [`SendStatement`] cannot be converted into a
/// [`TypedStatement`] because its bind parameter count or column count does
/// not match the typed statement's input or output type.
// `Debug` is implemented by hand so that it prints only the inner kind.
#[derive(Clone, PartialEq, Eq)]
pub struct TryFromSendStatementError {
    kind: TryFromSendStatementErrorKind,
}
impl fmt::Debug for TryFromSendStatementError {
    /// Debug-formats the error as its underlying kind, without a wrapper.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.kind, f)
    }
}
impl fmt::Display for TryFromSendStatementError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.kind {
TryFromSendStatementErrorKind::BindParameterCount { expected, actual } => write!(
f,
"unexpected bind parameter count for statement: expected {expected}, got {actual}",
),
TryFromSendStatementErrorKind::ColumnCount { expected, actual } => write!(
f,
"unexpected column count for statement: expected {expected}, got {actual}",
),
}
}
}
impl core::error::Error for TryFromSendStatementError {}
/// Wraps a [`SendStatement`] after checking that its shape matches `I` and `O`.
impl<I, O> TryFrom<SendStatement> for TypedStatement<I, O>
where
    I: Bind,
    O: for<'stmt> Row<'stmt>,
{
    type Error = TryFromSendStatementError;

    /// Converts `inner` into a typed statement.
    ///
    /// # Errors
    ///
    /// Returns a [`TryFromSendStatementError`] when the statement's bind
    /// parameter count differs from `I::COUNT`. It also errors when the
    /// column count differs from `O::COUNT`. The bind parameter count is
    /// checked first.
    fn try_from(inner: SendStatement) -> Result<Self, TryFromSendStatementError> {
        // Read each count once and reuse it both for the comparison and for
        // the error payload, instead of querying the statement twice.
        let expected = I::COUNT as i32;
        let actual = inner.bind_parameter_count();
        if actual != expected {
            return Err(TryFromSendStatementError {
                kind: TryFromSendStatementErrorKind::BindParameterCount { expected, actual },
            });
        }

        let expected = O::COUNT as i32;
        let actual = inner.column_count();
        if actual != expected {
            return Err(TryFromSendStatementError {
                kind: TryFromSendStatementErrorKind::ColumnCount { expected, actual },
            });
        }

        Ok(Self {
            inner,
            _marker: PhantomData,
        })
    }
}
/// A statement with its parameters bound, borrowed from a [`TypedStatement`].
///
/// Rows of type `O` are read with `next` or `first`. The underlying statement
/// is reset when this value is consumed by `first`/`reset` or dropped.
#[must_use = "dropping a `BoundStatement` immediately resets the statement, discarding the binding"]
pub struct BoundStatement<'stmt, I, O> {
    stmt: &'stmt mut SendStatement,
    // Marker only: no `I`/`O` values are stored, so use `fn() -> (I, O)` to
    // stay covariant without tying `Send`/`Sync` or drop checking to `I`/`O`.
    _marker: PhantomData<fn() -> (I, O)>,
}
impl<I, O> BoundStatement<'_, I, O> {
pub fn first(self) -> Result<Option<O>, Error>
where
O: for<'stmt> Row<'stmt>,
{
let value = self.stmt.next()?;
let mut this = ManuallyDrop::new(self);
this.stmt.reset()?;
Ok(value)
}
#[inline]
pub fn next(&mut self) -> Result<Option<O>, Error>
where
O: for<'stmt> Row<'stmt>,
{
self.stmt.next()
}
pub fn reset(self) -> Result<(), Error> {
let mut this = ManuallyDrop::new(self);
this.stmt.reset()?;
Ok(())
}
}
impl<I, O> Drop for BoundStatement<'_, I, O> {
    /// Resets the underlying statement when the binding goes out of scope.
    ///
    /// `drop` cannot report failures, so a failed reset is ignored here;
    /// `BoundStatement::reset` returns the error instead.
    fn drop(&mut self) {
        _ = self.stmt.reset();
    }
}