use std::future::Future;
use crate::data::{Date, DateTime, DateTimeTz, EncodeTarget, Encoder, Time};
use crate::rows::{collect_decoded_rows, decode_one_rows, decode_optional_rows};
use crate::{Column, Driver, Encode, ExecResult, FromRow, ParamValue, Result, RowRef};
/// A borrowed view of one positional SQL parameter.
///
/// Returned by [`ParamSource::value_at`]. Numeric variants are copied; the
/// temporal, UUID, text, and byte variants borrow from the source for `'a`.
#[derive(Debug, Clone, Copy)]
pub enum ParamRef<'a> {
    /// SQL `NULL`.
    Null,
    /// Signed 64-bit integer.
    I64(i64),
    /// Unsigned 64-bit integer.
    U64(u64),
    /// 64-bit floating-point number.
    F64(f64),
    /// Calendar date (year, month, day).
    Date(&'a Date),
    /// Time of day (hour, minute, second, microsecond).
    Time(&'a Time),
    /// Date plus time of day, with no offset.
    DateTime(&'a DateTime),
    /// Date plus time of day, with an offset in seconds (`offset_seconds`).
    DateTimeTz(&'a DateTimeTz),
    /// UUID as 16 raw bytes.
    Uuid(&'a [u8; 16]),
    /// UTF-8 text.
    Str(&'a str),
    /// Raw binary data.
    Bytes(&'a [u8]),
}
/// A positional list of SQL parameters that drivers read by index.
///
/// Implemented in this module for [`Params`], [`Query`], and slices / `Vec`s
/// of [`ParamValue`].
pub trait ParamSource {
    /// Returns the number of parameters.
    fn len(&self) -> usize;
    /// Returns `true` when there are no parameters (defaults to `self.len() == 0`).
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Returns a borrowed view of the parameter at `index`.
    ///
    /// NOTE(review): the implementations in this module index a `Vec`/slice
    /// directly and therefore panic when `index >= self.len()`; callers are
    /// presumably expected to stay in range.
    fn value_at(&self, index: usize) -> ParamRef<'_>;
}
/// A prepared SQL statement that can be executed with positional parameters.
pub trait PreparedStatement {
    /// Row stream returned by [`PreparedStatement::execute_source`]; it may
    /// borrow the statement for `'a`.
    type Rows<'a>: RowStream + 'a
    where
        Self: 'a;
    /// Executes the statement with `params` and resolves to a stream of rows
    /// borrowed from this statement.
    fn execute_source<P>(
        &mut self,
        params: &P,
    ) -> impl Future<Output = Result<Self::Rows<'_>>> + Send
    where
        P: ParamSource + Sync + ?Sized;
    /// Executes the statement with `params` and resolves to an [`ExecResult`]
    /// instead of a row stream.
    fn exec_source<P>(&mut self, params: &P) -> impl Future<Output = Result<ExecResult>> + Send
    where
        P: ParamSource + Sync + ?Sized;
    /// Starts a [`BoundStatement`] over this statement with `value` as its first
    /// parameter; chain further `bind` calls on the returned value.
    fn bind<T>(&mut self, value: T) -> BoundStatement<'_, Self>
    where
        Self: Sized,
        T: Encode,
    {
        BoundStatement::new(self).bind(value)
    }
}
/// A stream of result rows produced by an [`Executor`] or a [`PreparedStatement`].
pub trait RowStream {
    /// Returns the columns of this result set.
    fn columns(&self) -> &[Column];
    /// Fetches the next row. The row borrows the stream, so it must be dropped
    /// before `next` can be called again.
    ///
    /// NOTE(review): `Ok(None)` presumably marks the end of the stream; confirm
    /// against the driver implementations.
    fn next(&mut self) -> impl Future<Output = Result<Option<RowRef<'_>>>> + Send;
}
/// Something that can run SQL text and prepare statements.
pub trait Executor {
    /// Row stream produced by queries on this executor; it may borrow the
    /// executor for `'a`.
    type Rows<'a>: RowStream + 'a
    where
        Self: 'a;
    /// Prepared-statement type; its rows are the same type as [`Executor::Rows`].
    type Statement<'a>: PreparedStatement<Rows<'a> = Self::Rows<'a>> + 'a
    where
        Self: 'a;
    /// Returns this executor's [`Driver`].
    fn driver(&self) -> Driver;
    /// Runs `sql` with no bound parameters.
    fn query(&mut self, sql: &str) -> impl Future<Output = Result<Self::Rows<'_>>> + Send;
    /// Runs `sql` with the positional parameters supplied by `params`.
    fn query_prepared_source<P>(
        &mut self,
        sql: &str,
        params: &P,
    ) -> impl Future<Output = Result<Self::Rows<'_>>> + Send
    where
        P: ParamSource + Sync + ?Sized;
    /// Prepares `sql`, returning a statement that borrows this executor.
    fn prepare(&mut self, sql: &str) -> impl Future<Output = Result<Self::Statement<'_>>> + Send;
}
/// A SQL statement together with the parameters bound to it so far.
///
/// Created with [`query`], extended with [`Query::bind`], and then run with one
/// of its execution methods.
pub struct Query<'a> {
    // SQL text, borrowed from the caller.
    sql: &'a str,
    // Positional parameters, in bind order.
    params: Params,
}
/// Owned storage for one bound parameter inside [`Params`].
///
/// Text and binary values do not own their data. They record a
/// `start..start + len` byte range into the `Params` arena.
#[derive(Debug, Clone)]
enum ParamSlot {
    Null,
    I64(i64),
    U64(u64),
    F64(f64),
    Date(Date),
    Time(Time),
    DateTime(DateTime),
    DateTimeTz(DateTimeTz),
    Uuid([u8; 16]),
    /// UTF-8 text stored in the arena at `start..start + len`.
    Str { start: usize, len: usize },
    /// Raw bytes stored in the arena at `start..start + len`.
    Bytes { start: usize, len: usize },
}
/// Pending preparation of a SQL string, created with [`prepare`] and completed
/// with [`Prepare::run`].
pub struct Prepare<'a> {
    // SQL text to prepare, borrowed from the caller.
    sql: &'a str,
}
/// An owned, ordered list of bound SQL parameters.
///
/// String and byte values are copied into one shared arena buffer rather than
/// stored as separate allocations.
#[derive(Debug, Clone, Default)]
pub struct Params {
    // One slot per bound parameter, in bind order.
    params: Vec<ParamSlot>,
    // Concatenated bytes of every `Str`/`Bytes` parameter; slots index into it.
    arena: Vec<u8>,
}
/// A prepared statement paired with the parameters bound to it so far.
pub struct BoundStatement<'s, S: PreparedStatement + ?Sized> {
    // The statement being bound; borrowed mutably for `'s`.
    stmt: &'s mut S,
    // Positional parameters, in bind order.
    params: Params,
}
/// Starts building a query for `sql` with no parameters bound yet.
pub fn query(sql: &str) -> Query<'_> {
    let params = Params::default();
    Query { sql, params }
}
/// Starts preparing `sql`; pass an executor to [`Prepare::run`] to get the statement.
pub fn prepare(sql: &str) -> Prepare<'_> {
    Prepare { sql }
}
impl<'a> Query<'a> {
    /// Appends `value` as the next positional parameter and returns the query,
    /// for builder-style chaining.
    ///
    /// # Panics
    ///
    /// Panics if `value`'s [`Encode`] implementation writes no parameter.
    #[must_use]
    pub fn bind<T>(mut self, value: T) -> Self
    where
        T: Encode,
    {
        self.params.push(value);
        self
    }

    /// Runs the query on `exec` and returns its row stream.
    ///
    /// The returned rows borrow `exec`, not the query. A query with no bound
    /// parameters is sent through [`Executor::query`]. Otherwise it goes through
    /// [`Executor::query_prepared_source`], with this query as the parameter source.
    pub async fn fetch<E: Executor>(self, exec: &mut E) -> Result<E::Rows<'_>> {
        self.open_rows(exec).await
    }

    /// Runs the query and decodes a single row into `T`.
    ///
    /// `decode_one_rows` decides how empty or multi-row results are reported.
    pub async fn one<T>(self, mut exec: impl Executor) -> Result<T>
    where
        T: FromRow,
    {
        let mut rows = self.open_rows(&mut exec).await?;
        decode_one_rows(&mut rows).await
    }

    /// Runs the query and decodes an optional row into `T` via `decode_optional_rows`.
    pub async fn optional<T>(self, mut exec: impl Executor) -> Result<Option<T>>
    where
        T: FromRow,
    {
        let mut rows = self.open_rows(&mut exec).await?;
        decode_optional_rows(&mut rows).await
    }

    /// Runs the query and collects the decoded rows via `collect_decoded_rows`.
    pub async fn all<T>(self, mut exec: impl Executor) -> Result<Vec<T>>
    where
        T: FromRow,
    {
        let rows = self.open_rows(&mut exec).await?;
        collect_decoded_rows(rows).await
    }

    /// Prepares the SQL on `exec` and executes it with the bound parameters,
    /// returning an [`ExecResult`] instead of rows.
    ///
    /// Unlike the row-returning methods, this always goes through
    /// [`Executor::prepare`], even when no parameters are bound.
    pub async fn execute(self, mut exec: impl Executor) -> Result<ExecResult> {
        let mut stmt = exec.prepare(self.sql).await?;
        stmt.exec_source(&self).await
    }

    /// Starts the query on `exec`, using the plain path when nothing is bound
    /// and the parameterised path otherwise.
    ///
    /// Shared by every row-returning method, so the dispatch rule lives in one place.
    async fn open_rows<'e, E: Executor>(&self, exec: &'e mut E) -> Result<E::Rows<'e>> {
        if self.params.is_empty() {
            exec.query(self.sql).await
        } else {
            exec.query_prepared_source(self.sql, self).await
        }
    }
}
impl ParamSource for Query<'_> {
fn len(&self) -> usize {
self.params.len()
}
fn value_at(&self, index: usize) -> ParamRef<'_> {
self.params.value_at(index)
}
}
impl Params {
    /// Creates an empty parameter list.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `value` as the next positional parameter and returns the list,
    /// for builder-style chaining.
    ///
    /// # Panics
    ///
    /// Panics if `value`'s [`Encode`] implementation writes no parameter.
    #[must_use]
    pub fn bind<T>(mut self, value: T) -> Self
    where
        T: Encode,
    {
        self.push(value);
        self
    }

    /// Returns `true` when no parameters have been bound.
    pub fn is_empty(&self) -> bool {
        self.params.is_empty()
    }

    /// Returns the number of bound parameters.
    ///
    /// Public alongside [`Params::is_empty`], so callers can count parameters
    /// without importing [`ParamSource`].
    pub fn len(&self) -> usize {
        self.params.len()
    }

    /// Encodes `value` into exactly one new slot.
    ///
    /// # Panics
    ///
    /// Panics if the [`Encode`] implementation writes no parameter. The message
    /// names the offending type so the broken impl is easy to find.
    fn push<T>(&mut self, value: T)
    where
        T: Encode,
    {
        let before = self.params.len();
        let mut target = ParamsEncoder { params: self };
        let out = Encoder::new(&mut target);
        value.encode(out);
        let written = target.params.len() - before;
        match written {
            1 => {}
            0 => {
                panic!(
                    "Encode implementations must write exactly one SQL parameter, but wrote none \
                     (type: `{}`)",
                    std::any::type_name::<T>()
                )
            }
            _ => unreachable!("ownership-based Encoder should prevent multiple parameter writes"),
        }
    }

    /// Stores `value` in the arena and records it as a text slot.
    fn push_str(&mut self, value: &str) {
        let (start, len) = self.push_arena(value.as_bytes());
        self.params.push(ParamSlot::Str { start, len });
    }

    /// Stores `value` in the arena and records it as a binary slot.
    fn push_bytes(&mut self, value: &[u8]) {
        let (start, len) = self.push_arena(value);
        self.params.push(ParamSlot::Bytes { start, len });
    }

    /// Appends `bytes` to the arena and returns their `(start, len)` range.
    fn push_arena(&mut self, bytes: &[u8]) -> (usize, usize) {
        let start = self.arena.len();
        self.arena.extend_from_slice(bytes);
        (start, bytes.len())
    }
}
impl ParamSource for Params {
    /// Number of bound parameters.
    fn len(&self) -> usize {
        // Read the slot vector directly. A bare `self.len()` here only avoids
        // infinite recursion because inherent methods shadow trait methods,
        // which is easy to misread and breaks if the inherent `len` is removed.
        self.params.len()
    }

    /// Borrows the parameter at `index`, resolving arena ranges for text and
    /// byte values.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of range.
    fn value_at(&self, index: usize) -> ParamRef<'_> {
        match &self.params[index] {
            ParamSlot::Null => ParamRef::Null,
            ParamSlot::I64(value) => ParamRef::I64(*value),
            ParamSlot::U64(value) => ParamRef::U64(*value),
            ParamSlot::F64(value) => ParamRef::F64(*value),
            ParamSlot::Date(value) => ParamRef::Date(value),
            ParamSlot::Time(value) => ParamRef::Time(value),
            ParamSlot::DateTime(value) => ParamRef::DateTime(value),
            ParamSlot::DateTimeTz(value) => ParamRef::DateTimeTz(value),
            ParamSlot::Uuid(value) => ParamRef::Uuid(value),
            ParamSlot::Str { start, len } => {
                let bytes = &self.arena[*start..*start + *len];
                // `Str` slots are only created by `Params::push_str`, which copies
                // from a `&str`, so this can only fail if that invariant is broken.
                let value = std::str::from_utf8(bytes).expect("params arena stored invalid utf-8");
                ParamRef::Str(value)
            }
            ParamSlot::Bytes { start, len } => ParamRef::Bytes(&self.arena[*start..*start + *len]),
        }
    }
}
/// [`EncodeTarget`] that appends each encoded value to a [`Params`] list.
struct ParamsEncoder<'a> {
    // Destination list; borrowed for the duration of one `Params::push`.
    params: &'a mut Params,
}
impl EncodeTarget for ParamsEncoder<'_> {
fn encode_param(&mut self, value: ParamValue<'_>) {
match value {
ParamValue::Null => self.encode_null(),
ParamValue::I64(value) => self.encode_i64(value),
ParamValue::U64(value) => self.encode_u64(value),
ParamValue::F64(value) => self.encode_f64(value),
ParamValue::Str(value) => self.encode_str(value.as_ref()),
ParamValue::Bytes(value) => self.encode_bytes(value.as_ref()),
}
}
fn encode_null(&mut self) {
self.params.params.push(ParamSlot::Null);
}
fn encode_i64(&mut self, value: i64) {
self.params.params.push(ParamSlot::I64(value));
}
fn encode_u64(&mut self, value: u64) {
self.params.params.push(ParamSlot::U64(value));
}
fn encode_f64(&mut self, value: f64) {
self.params.params.push(ParamSlot::F64(value));
}
fn encode_bool(&mut self, value: bool) {
self.encode_i64(i64::from(value));
}
fn encode_date(&mut self, value: Date) {
self.params.params.push(ParamSlot::Date(value));
}
fn encode_time(&mut self, value: Time) {
self.params.params.push(ParamSlot::Time(value));
}
fn encode_datetime(&mut self, value: DateTime) {
self.params.params.push(ParamSlot::DateTime(value));
}
fn encode_datetime_tz(&mut self, value: DateTimeTz) {
self.params.params.push(ParamSlot::DateTimeTz(value));
}
fn encode_uuid(&mut self, value: [u8; 16]) {
self.params.params.push(ParamSlot::Uuid(value));
}
fn encode_str(&mut self, value: &str) {
self.params.push_str(value);
}
fn encode_string(&mut self, value: String) {
self.params.push_str(&value);
}
fn encode_bytes(&mut self, value: &[u8]) {
self.params.push_bytes(value);
}
fn encode_bytes_owned(&mut self, value: Vec<u8>) {
self.params.push_bytes(&value);
}
}
impl<'s, S> BoundStatement<'s, S>
where
    S: PreparedStatement + ?Sized,
{
    /// Starts binding parameters for `stmt`, beginning with an empty list.
    pub fn new(stmt: &'s mut S) -> Self {
        Self {
            stmt,
            params: Params::new(),
        }
    }

    /// Appends `value` as the next positional parameter and returns the bound
    /// statement, for builder-style chaining.
    ///
    /// # Panics
    ///
    /// Panics if `value`'s [`Encode`] implementation writes no parameter.
    #[must_use]
    pub fn bind<T>(mut self, value: T) -> Self
    where
        T: Encode,
    {
        self.params.push(value);
        self
    }

    /// Executes the statement with the bound parameters and returns its rows,
    /// which keep borrowing the statement for `'s`.
    pub async fn execute(self) -> Result<S::Rows<'s>> {
        self.stmt.execute_source(&self.params).await
    }

    /// Executes the statement with the bound parameters and returns an
    /// [`ExecResult`] instead of rows.
    pub async fn exec(self) -> Result<ExecResult> {
        self.stmt.exec_source(&self.params).await
    }

    /// Executes the statement and decodes a single row into `T`.
    ///
    /// `decode_one_rows` decides how empty or multi-row results are reported.
    pub async fn one<T>(self) -> Result<T>
    where
        T: FromRow,
    {
        let mut rows = self.execute().await?;
        decode_one_rows(&mut rows).await
    }

    /// Executes the statement and decodes an optional row into `T` via
    /// `decode_optional_rows`.
    pub async fn optional<T>(self) -> Result<Option<T>>
    where
        T: FromRow,
    {
        let mut rows = self.execute().await?;
        decode_optional_rows(&mut rows).await
    }

    /// Executes the statement and collects the decoded rows via `collect_decoded_rows`.
    pub async fn all<T>(self) -> Result<Vec<T>>
    where
        T: FromRow,
    {
        collect_decoded_rows(self.execute().await?).await
    }
}
impl<'a> Prepare<'a> {
pub async fn run<'e, E: Executor>(self, exec: &'e mut E) -> Result<E::Statement<'e>> {
exec.prepare(self.sql).await
}
}
/// A slice of [`ParamValue`]s is used directly as a parameter list.
impl ParamSource for [ParamValue<'_>] {
    fn len(&self) -> usize {
        // Explicit inherent call, so it cannot be confused with this trait's `len`.
        <[_]>::len(self)
    }

    /// Borrows the value at `index`; panics when out of range, as slice indexing does.
    fn value_at(&self, index: usize) -> ParamRef<'_> {
        let value = &self[index];
        param_value_ref(value)
    }
}
/// A `Vec` of [`ParamValue`]s behaves exactly like its slice.
impl ParamSource for Vec<ParamValue<'_>> {
    fn len(&self) -> usize {
        Vec::len(self)
    }

    fn value_at(&self, index: usize) -> ParamRef<'_> {
        self.as_slice().value_at(index)
    }
}
/// Borrows a [`ParamValue`] as a [`ParamRef`] without copying text or byte payloads.
///
/// The result lives only as long as the borrow of `value`. The argument type
/// `&'a ParamValue<'_>` already implies that the inner lifetime outlives `'a`,
/// so no explicit outlives bound is needed.
fn param_value_ref<'a>(value: &'a ParamValue<'_>) -> ParamRef<'a> {
    match value {
        ParamValue::Null => ParamRef::Null,
        ParamValue::I64(value) => ParamRef::I64(*value),
        ParamValue::U64(value) => ParamRef::U64(*value),
        ParamValue::F64(value) => ParamRef::F64(*value),
        ParamValue::Str(value) => ParamRef::Str(value.as_ref()),
        ParamValue::Bytes(value) => ParamRef::Bytes(value.as_ref()),
    }
}
/// Wraps a borrowed [`ParamSource`] so it can be passed to the MySQL/MariaDB
/// driver as a `quex_driver::mysql::ParamSource`.
#[cfg(feature = "mariadb")]
pub(crate) struct MysqlParamSource<'a, P: ?Sized>(pub(crate) &'a P);
#[cfg(feature = "mariadb")]
impl<P> quex_driver::mysql::ParamSource for MysqlParamSource<'_, P>
where
    P: ParamSource + ?Sized,
{
    /// Reports how many parameters the wrapped source holds.
    fn len(&self) -> usize {
        ParamSource::len(self.0)
    }

    /// Translates the wrapped source's parameter at `index` into the MySQL
    /// driver's value type. Borrowed payloads are passed through as-is, and
    /// temporal values are rebuilt field by field.
    fn value_at(&self, index: usize) -> quex_driver::mysql::ValueRef<'_> {
        use quex_driver::mysql::{DateTimeTzValue, DateTimeValue, DateValue, TimeValue, ValueRef};

        // Local conversion helpers, so the nested `DateTime`/`DateTimeTz` cases
        // reuse the `Date`/`Time` conversions instead of repeating them.
        fn date(src: &Date) -> DateValue {
            DateValue {
                year: src.year,
                month: src.month,
                day: src.day,
            }
        }

        fn time(src: &Time) -> TimeValue {
            TimeValue {
                hour: src.hour,
                minute: src.minute,
                second: src.second,
                microsecond: src.microsecond,
            }
        }

        fn datetime(src: &DateTime) -> DateTimeValue {
            DateTimeValue {
                date: date(&src.date),
                time: time(&src.time),
            }
        }

        match self.0.value_at(index) {
            ParamRef::Null => ValueRef::Null,
            ParamRef::I64(n) => ValueRef::I64(n),
            ParamRef::U64(n) => ValueRef::U64(n),
            ParamRef::F64(x) => ValueRef::F64(x),
            ParamRef::Date(d) => ValueRef::Date(date(d)),
            ParamRef::Time(t) => ValueRef::Time(time(t)),
            ParamRef::DateTime(dt) => ValueRef::DateTime(datetime(dt)),
            ParamRef::DateTimeTz(tz) => ValueRef::DateTimeTz(DateTimeTzValue {
                datetime: datetime(&tz.datetime),
                offset_seconds: tz.offset_seconds,
            }),
            ParamRef::Uuid(id) => ValueRef::Uuid(id),
            ParamRef::Str(text) => ValueRef::String(text),
            ParamRef::Bytes(raw) => ValueRef::Bytes(raw),
        }
    }
}
/// Wraps a borrowed [`ParamSource`] so it can be passed to the PostgreSQL
/// driver as a `quex_driver::postgres::ParamSource`.
#[cfg(feature = "postgres")]
pub(crate) struct PostgresParamSource<'a, P: ?Sized>(pub(crate) &'a P);
#[cfg(feature = "postgres")]
impl<P> quex_driver::postgres::ParamSource for PostgresParamSource<'_, P>
where
    P: ParamSource + ?Sized,
{
    /// Reports how many parameters the wrapped source holds.
    fn len(&self) -> usize {
        ParamSource::len(self.0)
    }

    /// Translates the wrapped source's parameter at `index` into the PostgreSQL
    /// driver's value type. Borrowed payloads are passed through as-is, and
    /// temporal values are rebuilt field by field.
    fn value_at(&self, index: usize) -> quex_driver::postgres::ValueRef<'_> {
        use quex_driver::postgres::{
            DateTimeTzValue, DateTimeValue, DateValue, TimeValue, ValueRef,
        };

        // Local conversion helpers, so the nested `DateTime`/`DateTimeTz` cases
        // reuse the `Date`/`Time` conversions instead of repeating them.
        fn date(src: &Date) -> DateValue {
            DateValue {
                year: src.year,
                month: src.month,
                day: src.day,
            }
        }

        fn time(src: &Time) -> TimeValue {
            TimeValue {
                hour: src.hour,
                minute: src.minute,
                second: src.second,
                microsecond: src.microsecond,
            }
        }

        fn datetime(src: &DateTime) -> DateTimeValue {
            DateTimeValue {
                date: date(&src.date),
                time: time(&src.time),
            }
        }

        match self.0.value_at(index) {
            ParamRef::Null => ValueRef::Null,
            ParamRef::I64(n) => ValueRef::I64(n),
            ParamRef::U64(n) => ValueRef::U64(n),
            ParamRef::F64(x) => ValueRef::F64(x),
            ParamRef::Date(d) => ValueRef::Date(date(d)),
            ParamRef::Time(t) => ValueRef::Time(time(t)),
            ParamRef::DateTime(dt) => ValueRef::DateTime(datetime(dt)),
            ParamRef::DateTimeTz(tz) => ValueRef::DateTimeTz(DateTimeTzValue {
                datetime: datetime(&tz.datetime),
                offset_seconds: tz.offset_seconds,
            }),
            ParamRef::Uuid(id) => ValueRef::Uuid(id),
            ParamRef::Str(text) => ValueRef::String(text),
            ParamRef::Bytes(raw) => ValueRef::Bytes(raw),
        }
    }
}
/// Encoding through a shared reference delegates to the referent, so `&T` can
/// be bound anywhere `T` can.
impl<T> Encode for &T
where
    T: Encode + ?Sized,
{
    fn encode(&self, out: Encoder<'_>) {
        <T as Encode>::encode(*self, out);
    }
}