use std::fmt::{self, Write};
use serde::Deserialize;
use zerocopy::TryFromBytes;
use crate::{SdkError, wit};
use super::{
ConnectionLike,
types::{DatabaseType, DatabaseValue},
};
/// A finished SQL statement together with its bound parameters.
///
/// Start one with [`Query::builder`], then run it with [`Query::execute`] or
/// [`Query::fetch`].
#[derive(Clone, Debug)]
pub struct Query {
    /// SQL text.
    pub(crate) query: String,
    /// Bound parameters, in the order they were bound.
    pub(crate) values: Vec<wit::PgBoundValue>,
    /// Flattened array elements that array parameters in `values` refer to by index.
    pub(crate) value_tree: wit::PgValueTree,
}

impl fmt::Display for Query {
    /// Writes only the SQL text; bound parameter values are not rendered.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.query.as_str())
    }
}

impl Query {
    /// Returns an empty [`QueryBuilder`].
    pub fn builder() -> QueryBuilder {
        Default::default()
    }

    /// Sends the statement and its parameters to `connection` for execution.
    ///
    /// Returns the `u64` reported by the connection (presumably the number of
    /// affected rows — TODO confirm against `ConnectionLike::execute`).
    ///
    /// # Errors
    ///
    /// Propagates any [`SdkError`] returned by the connection.
    pub fn execute<'a>(self, connection: impl Into<ConnectionLike<'a>>) -> Result<u64, SdkError> {
        let Self {
            query,
            values,
            value_tree,
        } = self;

        connection
            .into()
            .execute(&query, (values.as_slice(), &value_tree))
    }

    /// Runs the statement on `connection` and returns one [`ColumnIterator`] per
    /// result row, each positioned at the row's first column.
    ///
    /// # Errors
    ///
    /// Propagates any [`SdkError`] returned by the connection.
    pub fn fetch<'a>(
        self,
        connection: impl Into<ConnectionLike<'a>>,
    ) -> Result<impl Iterator<Item = ColumnIterator>, SdkError> {
        let Self {
            query,
            values,
            value_tree,
        } = self;

        let rows = connection
            .into()
            .query(&query, (values.as_slice(), &value_tree))?;

        Ok(rows.into_iter().map(|row| {
            let length = row.len() as usize;
            ColumnIterator {
                position: 0,
                length,
                row,
            }
        }))
    }
}
/// Iterator over the columns of a single result row.
///
/// Each item is the raw value of one column, or the error returned while reading
/// that column. A failed column is reported once and iteration then continues with
/// the next column, so the iterator always terminates after `length` items.
pub struct ColumnIterator {
    /// Index of the next column to read.
    position: usize,
    /// Total number of columns in `row`.
    length: usize,
    /// The row returned by the host.
    row: wit::PgRow,
}

impl Iterator for ColumnIterator {
    type Item = Result<RowValue, SdkError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.position >= self.length {
            return None;
        }

        // Advance before reading: if the read fails, the error is yielded once and the
        // next call moves on, instead of re-reading the same column forever.
        let index = self.position;
        self.position += 1;

        let item = self
            .row
            .as_bytes(index as u64)
            .map(|value| RowValue { value })
            .map_err(SdkError::from);
        Some(item)
    }

    /// Exact: every call to `next` consumes one column, whether it succeeds or fails.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.length.saturating_sub(self.position);
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for ColumnIterator {}

// Once `position` reaches `length`, `next` keeps returning `None`.
impl std::iter::FusedIterator for ColumnIterator {}
/// The raw contents of one column in a result row.
pub struct RowValue {
    /// Bytes returned by the host for this column; `None` when the host returned no
    /// bytes (presumably a SQL `NULL` — TODO confirm).
    value: Option<Vec<u8>>,
}

impl RowValue {
    /// Borrows the raw column bytes, or `None` if there are none.
    pub fn bytes(&self) -> Option<&[u8]> {
        self.value.as_ref().map(Vec::as_slice)
    }

    /// Consumes the value and returns the owned column bytes, or `None` if there are none.
    pub fn into_bytes(self) -> Option<Vec<u8>> {
        let Self { value } = self;
        value
    }

    /// Interprets the column bytes as UTF-8 text without copying.
    ///
    /// Returns `Ok(None)` when there are no bytes.
    ///
    /// # Errors
    ///
    /// Returns an [`SdkError`] when the bytes are not valid UTF-8.
    pub fn as_str(&self) -> Result<Option<&str>, SdkError> {
        match self.value.as_deref() {
            None => Ok(None),
            Some(bytes) => std::str::from_utf8(bytes)
                .map(Some)
                .map_err(|e| SdkError::from(format!("Failed to convert bytes to string: {e}"))),
        }
    }

    /// Reinterprets the column bytes as a `T` using zerocopy's
    /// [`TryFromBytes::try_read_from_bytes`].
    ///
    /// Returns `Ok(None)` when there are no bytes.
    ///
    /// # Errors
    ///
    /// Returns an [`SdkError`] when zerocopy rejects the bytes (e.g. the length does not
    /// match the size of `T`, or the bytes are not a valid `T`).
    pub fn as_value<T>(&self) -> Result<Option<T>, SdkError>
    where
        T: TryFromBytes,
    {
        // NOTE(review): zerocopy reads in the target's native byte order; confirm the host
        // has already converted PostgreSQL's network-order binary encoding.
        match self.value.as_deref() {
            None => Ok(None),
            Some(bytes) => match T::try_read_from_bytes(bytes) {
                Ok(decoded) => Ok(Some(decoded)),
                Err(e) => Err(SdkError::from(format!(
                    "Failed to convert bytes to primitive: {e:?}"
                ))),
            },
        }
    }

    /// Parses the column bytes as JSON into a `T`.
    ///
    /// Returns `Ok(None)` when there are no bytes. The bytes are deserialized as
    /// `Option<T>`, so a JSON `null` also yields `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Returns an [`SdkError`] when the bytes cannot be deserialized as `Option<T>`.
    pub fn as_json<T>(&self) -> Result<Option<T>, SdkError>
    where
        T: for<'a> Deserialize<'a>,
    {
        self.value.as_deref().map_or(Ok(None), |bytes| {
            serde_json::from_slice::<Option<T>>(bytes).map_err(SdkError::from)
        })
    }
}
/// Incrementally assembles a [`Query`].
///
/// SQL text is appended through the [`std::fmt::Write`] implementation (for example
/// with `write!`), and parameters are added with [`QueryBuilder::bind`] or
/// [`QueryBuilder::bind_value`].
#[derive(Debug, Default)]
pub struct QueryBuilder {
    /// SQL text written so far.
    query: String,
    /// Bound parameters, in binding order.
    values: Vec<wit::PgBoundValue>,
    /// Flattened array elements that array parameters in `values` refer to by index.
    value_tree: wit::PgValueTree,
}

impl QueryBuilder {
    /// Converts `value` with [`DatabaseType::into_bound_value`] and binds the result
    /// through [`QueryBuilder::bind_value`].
    ///
    /// The current length of the value tree is passed as the `into_bound_value` argument.
    pub fn bind(&mut self, value: impl DatabaseType) {
        // NOTE(review): `bind_value` already shifts array indices by the value-tree length.
        // If `into_bound_value` also treats this argument as an index offset, arrays bound
        // after the first would be shifted twice — confirm its contract.
        let tree_len = self.value_tree.len() as u64;
        let converted = value.into_bound_value(tree_len);
        self.bind_value(converted);
    }

    /// Appends an already-converted parameter.
    ///
    /// Indices held by a [`wit::PgValue::Array`] parameter are shifted by the current
    /// value-tree length before the parameter is stored, and any accompanying
    /// `array_values` are then appended to the value tree (so the incoming indices are
    /// presumably relative to `array_values`).
    pub fn bind_value(&mut self, value: DatabaseValue) {
        let DatabaseValue {
            value: mut bound,
            array_values,
        } = value;

        if let wit::PgValue::Array(indices) = &mut bound.value {
            let offset = self.value_tree.len() as u64;
            for index in indices.iter_mut() {
                *index += offset;
            }
        }

        self.values.push(bound);

        if let Some(elements) = array_values {
            self.value_tree.extend(elements);
        }
    }

    /// Consumes the builder and produces the finished [`Query`].
    pub fn finalize(self) -> Query {
        let Self {
            query,
            values,
            value_tree,
        } = self;

        Query {
            query,
            values,
            value_tree,
        }
    }

    /// Number of parameters bound so far.
    pub fn bound_values(&self) -> usize {
        self.values.len()
    }
}

impl Write for QueryBuilder {
    /// Appends `s` to the SQL text; this never fails.
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        self.query.push_str(s);
        Ok(())
    }
}