use std::{borrow::Cow, fmt::Write, marker::PhantomData};
use crate::{
BigInt, FromRow, LowerCompatible, Qrafting,
builder::{Insert, Insertable},
cte::{Cte, IntoCtes},
dialect::HasDialect,
emitter::{Directive, Emitter},
expression::{Operator, Scalar, prepare_sqlite_glob, random},
insert_into,
instr::RpnInstr,
lower::{Data, Instructions, LowerCtx},
param::Param,
query::{
JoinKind, LowerFilter, LowerFrom, LowerGroupBy, LowerHaving, LowerJoin, LowerOrderBy,
LowerProject,
},
span::TextSource,
ty::TypeMeta,
};
/// Set operator joining a query with the right-hand side stored in [`CompoundQuery`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum SetOperator {
    /// `union` (duplicates removed).
    Union,
    /// `union all` (duplicates kept).
    UnionAll,
}
/// Right-hand side of a set operation (`union` / `union all`) attached to a [`Query`].
#[derive(Clone, Debug)]
pub(crate) struct CompoundQuery {
    /// Which set operator combines the owning query with `right`.
    pub operator: SetOperator,
    /// The query on the right of the operator; its parameter ids and data spans
    /// have been rebased into the owning query by `Query::merge_query`.
    pub right: Box<Query>,
}
/// Row-locking clause appended to a `select`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LockClause {
    /// Exclusive lock (`for update`), with optional `skip locked` / `nowait` modifiers.
    ForUpdate { skip_locked: bool, nowait: bool },
    /// Shared lock (e.g. `for share`), with optional `skip locked` / `nowait` modifiers.
    Shared { skip_locked: bool, nowait: bool },
}
/// An untyped `select` statement under construction.
///
/// Each clause is stored as a sequence of [`RpnInstr`] instructions. Parameter
/// instructions (`RpnInstr::Param { id }`) index into `params`; text spans
/// index into the shared `data` buffer.
#[derive(Clone, Debug)]
pub struct Query {
    /// Select-list instructions.
    pub(crate) project: Vec<RpnInstr>,
    /// Whether `select distinct` is emitted.
    pub(crate) distinct: bool,
    /// Source table(s) followed by any joins.
    pub(crate) from: Vec<RpnInstr>,
    /// `where` predicate; successive filters are combined with `push_binary`.
    pub(crate) filters: Vec<RpnInstr>,
    /// `having` predicate; successive conditions are combined with `push_binary`.
    pub(crate) havings: Vec<RpnInstr>,
    /// `group by` expressions.
    pub(crate) group_by: Vec<RpnInstr>,
    /// `order by` expressions.
    pub(crate) order_by: Vec<RpnInstr>,
    /// `limit` expression (at most one; replaced on every `limit` call).
    pub(crate) limit: Vec<RpnInstr>,
    /// `offset` expression (at most one; replaced on every `offset` call).
    pub(crate) offset: Vec<RpnInstr>,
    /// Optional row-locking clause.
    pub(crate) lock: Option<LockClause>,
    /// `with` entries; their parameters occupy the front of `params`.
    pub(crate) ctes: Vec<Cte>,
    /// Optional `union` / `union all` right-hand side.
    pub(crate) compound: Option<CompoundQuery>,
    /// Bound parameter values referenced by `RpnInstr::Param` ids.
    pub(crate) params: Vec<Param>,
    /// Byte buffer backing data-resident text spans.
    pub(crate) data: Vec<u8>,
}
/// A [`Query`] whose rows are decoded as `M`.
///
/// The row type exists only at the type level (`PhantomData`); the wrapped
/// query is identical to an untyped one.
#[derive(Debug)]
pub struct QueryOf<M> {
    /// The underlying untyped query.
    pub(crate) query: Query,
    /// Zero-sized row-type tag.
    pub(crate) marker: PhantomData<M>,
}
/// An untyped [`Query`] that carries a row-locking clause.
///
/// Produced by `Query::lock_for_update` / `Query::shared_lock`; exposes the
/// lock modifiers (`skip_locked`, `nowait`) that plain queries do not.
#[derive(Clone, Debug)]
pub struct LockedQuery {
    /// The underlying query, with `lock` set.
    pub(crate) query: Query,
}
/// A locked query whose rows are decoded as `M`.
///
/// Typed counterpart of [`LockedQuery`]; the row type is a zero-sized tag.
#[derive(Debug)]
pub struct LockedQueryOf<M> {
    /// The underlying query, expected to have `lock` set.
    pub(crate) query: Query,
    /// Zero-sized row-type tag.
    pub(crate) marker: PhantomData<M>,
}
impl<M> From<QueryOf<M>> for Query {
fn from(value: QueryOf<M>) -> Self {
value.query
}
}
impl From<LockedQuery> for Query {
fn from(value: LockedQuery) -> Self {
value.query
}
}
impl<M> From<LockedQueryOf<M>> for Query {
fn from(value: LockedQueryOf<M>) -> Self {
value.query
}
}
/// Tags an untyped [`Query`] with the row type `M`.
impl<M: FromRow> From<Query> for QueryOf<M> {
    fn from(value: Query) -> Self {
        value.typed()
    }
}
/// Tags a [`LockedQuery`] with the row type `M`.
impl<M: FromRow> From<LockedQuery> for LockedQueryOf<M> {
    fn from(value: LockedQuery) -> Self {
        value.typed()
    }
}
impl<M> QueryOf<M>
where
M: FromRow,
{
pub fn new(query: Query) -> Self {
Self {
query,
marker: PhantomData,
}
}
pub fn select<S>(self, s: S) -> Self
where
S: LowerProject,
{
self.query.select(s).typed()
}
pub fn left_join<J>(self, join: J) -> Self
where
J: LowerJoin,
{
self.query.left_join(join).typed()
}
pub fn right_join<J>(self, join: J) -> Self
where
J: LowerJoin,
{
self.query.right_join(join).typed()
}
pub fn inner_join<J>(self, join: J) -> Self
where
J: LowerJoin,
{
self.query.inner_join(join).typed()
}
pub fn filter<E>(self, e: E) -> Self
where
E: LowerFilter,
{
self.query.filter(e).typed()
}
pub fn or_filter<E>(self, e: E) -> Self
where
E: LowerFilter,
{
self.query.or_filter(e).typed()
}
pub fn order_by<T>(self, by: T) -> Self
where
T: LowerOrderBy,
{
self.query.order_by(by).typed()
}
pub fn order_by_random(self) -> Self {
self.order_by(random())
}
pub fn limit<E>(self, e: E) -> Self
where
E: LowerCompatible<BigInt>,
{
self.query.limit(e).typed()
}
pub fn reset_limit(&mut self) {
self.query.reset_limit();
}
pub fn offset<E>(self, e: E) -> Self
where
E: LowerCompatible<BigInt>,
{
self.query.offset(e).typed()
}
pub fn reset_offset(&mut self) {
self.query.reset_offset();
}
pub fn distinct(self) -> Self {
self.query.distinct().typed()
}
pub fn not_distinct(self) -> Self {
self.query.not_distinct().typed()
}
pub fn lock_for_update(self) -> LockedQueryOf<M> {
self.query.lock_for_update().typed()
}
pub fn shared_lock(self) -> LockedQueryOf<M> {
self.query.shared_lock().typed()
}
pub fn with<C>(self, ctes: C) -> Self
where
C: IntoCtes,
{
self.query.with(ctes).typed()
}
pub fn with_recursive<C>(self, ctes: C) -> Self
where
C: IntoCtes,
{
self.query.with_recursive(ctes).typed()
}
pub fn untyped(self) -> Query {
self.query
}
pub fn typed<N>(self) -> QueryOf<N>
where
N: FromRow,
{
self.query.typed()
}
pub fn scalar<T: TypeMeta>(self) -> Scalar<T> {
self.query.scalar::<T>()
}
pub fn union(self, other: impl Into<Query>) -> Self {
self.query.union(other).typed()
}
pub fn union_all(self, other: impl Into<Query>) -> Self {
self.query.union_all(other).typed()
}
#[allow(clippy::wrong_self_convention)]
pub fn to_sql<D: HasDialect>(&mut self) -> String {
self.query.to_sql::<D>()
}
#[allow(clippy::wrong_self_convention)]
pub fn to_debug_sql<D: HasDialect>(&mut self) -> String {
self.query.to_debug_sql::<D>()
}
pub fn into_compiled<D: HasDialect>(self) -> TypedCompiled<M> {
let c = self.query.into_compiled::<D>();
TypedCompiled {
sql: c.sql,
params: c.params,
data: c.data,
marker: PhantomData,
}
}
}
/// Insert helpers for typed queries over models that know their table.
impl<M: FromRow + Qrafting> QueryOf<M> {
    /// Builds an `insert` into `M::TABLE` carrying `values`.
    ///
    /// Only `M::TABLE` is used; the receiver query's clauses are not consulted.
    pub fn create<V: Insertable<M>>(self, values: V) -> Insert<M> {
        let target = crate::query::Table::new(M::TABLE);
        insert_into(target).values(values)
    }
}
impl<M> LockedQueryOf<M>
where
M: FromRow,
{
pub fn new(query: Query) -> Self {
Self {
query,
marker: PhantomData,
}
}
pub fn filter<E>(self, e: E) -> Self
where
E: LowerFilter,
{
Self::new(self.query.filter(e))
}
pub fn or_filter<E>(self, e: E) -> Self
where
E: LowerFilter,
{
Self::new(self.query.or_filter(e))
}
pub fn order_by<T>(self, by: T) -> Self
where
T: LowerOrderBy,
{
Self::new(self.query.order_by(by))
}
pub fn order_by_random(self) -> Self {
self.order_by(random())
}
pub fn limit<E>(self, e: E) -> Self
where
E: LowerCompatible<BigInt>,
{
Self::new(self.query.limit(e))
}
pub fn offset<E>(self, e: E) -> Self
where
E: LowerCompatible<BigInt>,
{
Self::new(self.query.offset(e))
}
pub fn distinct(self) -> Self {
Self::new(self.query.distinct())
}
pub fn not_distinct(self) -> Self {
Self::new(self.query.not_distinct())
}
pub fn skip_locked(self) -> Self {
Self::new(self.query.skip_locked())
}
pub fn nowait(self) -> Self {
Self::new(self.query.nowait())
}
pub fn with<C>(self, ctes: C) -> Self
where
C: IntoCtes,
{
Self::new(self.query.with(ctes))
}
pub fn with_recursive<C>(self, ctes: C) -> Self
where
C: IntoCtes,
{
Self::new(self.query.with_recursive(ctes))
}
pub fn untyped(self) -> Query {
self.query
}
pub fn scalar<T: TypeMeta>(self) -> Scalar<T> {
self.query.scalar::<T>()
}
#[allow(clippy::wrong_self_convention)]
pub fn to_sql<D: HasDialect>(&mut self) -> String {
self.query.to_sql::<D>()
}
#[allow(clippy::wrong_self_convention)]
pub fn to_debug_sql<D: HasDialect>(&mut self) -> String {
self.query.to_debug_sql::<D>()
}
pub fn into_compiled<D: HasDialect>(self) -> TypedCompiled<M> {
let c = self.query.into_compiled::<D>();
TypedCompiled {
sql: c.sql,
params: c.params,
data: c.data,
marker: PhantomData,
}
}
}
impl LockedQuery {
pub fn new(query: Query) -> Self {
Self { query }
}
pub fn select<S>(self, s: S) -> Self
where
S: LowerProject,
{
Self::new(self.query.select(s))
}
pub fn left_join<J>(self, join: J) -> Self
where
J: LowerJoin,
{
Self::new(self.query.left_join(join))
}
pub fn right_join<J>(self, join: J) -> Self
where
J: LowerJoin,
{
Self::new(self.query.right_join(join))
}
pub fn inner_join<J>(self, join: J) -> Self
where
J: LowerJoin,
{
Self::new(self.query.inner_join(join))
}
pub fn having<E>(self, e: E) -> Self
where
E: LowerHaving,
{
Self::new(self.query.having(e))
}
pub fn or_having<E>(self, e: E) -> Self
where
E: LowerHaving,
{
Self::new(self.query.or_having(e))
}
pub fn filter<E>(self, e: E) -> Self
where
E: LowerFilter,
{
Self::new(self.query.filter(e))
}
pub fn or_filter<E>(self, e: E) -> Self
where
E: LowerFilter,
{
Self::new(self.query.or_filter(e))
}
pub fn group_by<T>(self, by: T) -> Self
where
T: LowerGroupBy,
{
Self::new(self.query.group_by(by))
}
pub fn order_by<T>(self, by: T) -> Self
where
T: LowerOrderBy,
{
Self::new(self.query.order_by(by))
}
pub fn order_by_random(self) -> Self {
self.order_by(random())
}
pub fn limit<E>(self, e: E) -> Self
where
E: LowerCompatible<BigInt>,
{
Self::new(self.query.limit(e))
}
pub fn offset<E>(self, e: E) -> Self
where
E: LowerCompatible<BigInt>,
{
Self::new(self.query.offset(e))
}
pub fn distinct(self) -> Self {
Self::new(self.query.distinct())
}
pub fn not_distinct(self) -> Self {
Self::new(self.query.not_distinct())
}
pub fn skip_locked(self) -> Self {
Self::new(self.query.skip_locked())
}
pub fn nowait(self) -> Self {
Self::new(self.query.nowait())
}
pub fn with<C>(self, ctes: C) -> Self
where
C: IntoCtes,
{
Self::new(self.query.with(ctes))
}
pub fn with_recursive<C>(self, ctes: C) -> Self
where
C: IntoCtes,
{
Self::new(self.query.with_recursive(ctes))
}
pub fn typed<M>(self) -> LockedQueryOf<M>
where
M: FromRow,
{
LockedQueryOf::new(self.query)
}
pub fn scalar<T: TypeMeta>(self) -> Scalar<T> {
self.query.scalar::<T>()
}
#[allow(clippy::wrong_self_convention)]
pub fn to_sql<D: HasDialect>(&mut self) -> String {
self.query.to_sql::<D>()
}
#[allow(clippy::wrong_self_convention)]
pub fn to_debug_sql<D: HasDialect>(&mut self) -> String {
self.query.to_debug_sql::<D>()
}
pub fn into_compiled<D: HasDialect>(self) -> Compiled {
self.query.into_compiled::<D>()
}
}
impl Query {
pub(crate) fn default() -> Self {
Self {
project: Vec::default(),
filters: Vec::default(),
params: Vec::default(),
data: Vec::default(),
from: Vec::default(),
havings: Vec::default(),
group_by: Vec::default(),
order_by: Vec::default(),
limit: Vec::default(),
distinct: false,
offset: Vec::default(),
lock: None,
ctes: Vec::default(),
compound: None,
}
}
fn merge_query(&mut self, other: &mut Query) {
let data_offset = self.data.len() as u32;
let param_offset = self.params.len() as u32;
self.data.extend_from_slice(&other.data);
offset_query_data(other, data_offset);
shift_query_params(other, param_offset);
self.params.extend(other.params.iter().copied());
}
fn where_clause<E>(mut self, e: E, op: Operator) -> Self
where
E: LowerFilter,
{
let is_first = self.filters.is_empty();
let lhs = self.filters.len();
let mut ctx = LowerCtx {
instrs: &mut self.filters,
params: &mut self.params,
data: &mut self.data,
};
e.lower_filter(&mut ctx);
if !is_first {
self.filters.push_binary(op, lhs, self.filters.len() - lhs);
}
self
}
fn having_clause<E>(mut self, e: E, op: Operator) -> Self
where
E: LowerHaving,
{
let is_first = self.havings.is_empty();
let lhs = self.havings.len();
let mut ctx = LowerCtx {
instrs: &mut self.havings,
params: &mut self.params,
data: &mut self.data,
};
e.lower_having(&mut ctx);
if !is_first {
self.havings.push_binary(op, lhs, self.havings.len() - lhs);
}
self
}
pub fn select<S>(mut self, s: S) -> Self
where
S: LowerProject,
{
self.reset_select();
let mut ctx = LowerCtx {
instrs: &mut self.project,
params: &mut self.params,
data: &mut self.data,
};
s.lower_project(&mut ctx);
self
}
pub fn reset_select(&mut self) {
if !self.project.is_empty() {
self.project.clear();
}
}
pub fn left_join<J>(mut self, join: J) -> Self
where
J: LowerJoin,
{
let mut ctx = LowerCtx {
instrs: &mut self.from,
params: &mut self.params,
data: &mut self.data,
};
join.lower_join(JoinKind::Left, &mut ctx);
self
}
pub fn right_join<J>(mut self, join: J) -> Self
where
J: LowerJoin,
{
let mut ctx = LowerCtx {
instrs: &mut self.from,
params: &mut self.params,
data: &mut self.data,
};
join.lower_join(JoinKind::Right, &mut ctx);
self
}
pub fn inner_join<J>(mut self, join: J) -> Self
where
J: LowerJoin,
{
let mut ctx = LowerCtx {
instrs: &mut self.from,
params: &mut self.params,
data: &mut self.data,
};
join.lower_join(JoinKind::Inner, &mut ctx);
self
}
pub fn having<E>(self, e: E) -> Self
where
E: LowerHaving,
{
self.having_clause(e, Operator::And)
}
pub fn or_having<E>(self, e: E) -> Self
where
E: LowerHaving,
{
self.having_clause(e, Operator::Or)
}
pub fn filter<E>(self, e: E) -> Self
where
E: LowerFilter,
{
self.where_clause(e, Operator::And)
}
pub fn or_filter<E>(self, e: E) -> Self
where
E: LowerFilter,
{
self.where_clause(e, Operator::Or)
}
pub fn from<T>(table: T) -> Self
where
T: LowerFrom,
{
let mut me = Self::default();
let mut ctx = LowerCtx {
instrs: &mut me.from,
params: &mut me.params,
data: &mut me.data,
};
table.lower_from(&mut ctx);
me
}
pub fn union(self, other: impl Into<Query>) -> Self {
self.compound_with(other.into(), SetOperator::Union)
}
pub fn union_all(self, other: impl Into<Query>) -> Self {
self.compound_with(other.into(), SetOperator::UnionAll)
}
fn compound_with(mut self, mut other: Query, operator: SetOperator) -> Self {
self.merge_query(&mut other);
self.compound = Some(CompoundQuery {
operator,
right: Box::new(other),
});
self
}
pub(crate) fn base_table_static(&self) -> Option<&'static str> {
match self.from.first() {
Some(RpnInstr::Table {
span: crate::span::TextSpan(TextSource::StaticText(table)),
}) => Some(*table),
_ => None,
}
}
pub fn group_by<T>(mut self, by: T) -> Self
where
T: LowerGroupBy,
{
let mut ctx = LowerCtx {
instrs: &mut self.group_by,
params: &mut self.params,
data: &mut self.data,
};
by.lower_group_by(&mut ctx);
self
}
pub fn order_by<T>(mut self, by: T) -> Self
where
T: LowerOrderBy,
{
let mut ctx = LowerCtx {
instrs: &mut self.order_by,
params: &mut self.params,
data: &mut self.data,
};
by.lower_order_by(&mut ctx);
self
}
pub fn order_by_random(self) -> Self {
self.order_by(random())
}
pub fn limit<E>(mut self, e: E) -> Self
where
E: LowerCompatible<BigInt>,
{
self.reset_limit();
let mut ctx = LowerCtx {
instrs: &mut self.limit,
params: &mut self.params,
data: &mut self.data,
};
e.lower_compatible(&mut ctx);
self
}
pub fn reset_limit(&mut self) {
if !self.limit.is_empty() {
self.limit.clear();
}
}
pub fn offset<E>(mut self, e: E) -> Self
where
E: LowerCompatible<BigInt>,
{
self.reset_offset();
let mut ctx = LowerCtx {
instrs: &mut self.offset,
params: &mut self.params,
data: &mut self.data,
};
e.lower_compatible(&mut ctx);
self
}
pub fn reset_offset(&mut self) {
if !self.offset.is_empty() {
self.offset.clear();
}
}
pub fn lock_for_update(mut self) -> LockedQuery {
self.lock = Some(LockClause::ForUpdate {
skip_locked: false,
nowait: false,
});
LockedQuery { query: self }
}
pub fn shared_lock(mut self) -> LockedQuery {
self.lock = Some(LockClause::Shared {
skip_locked: false,
nowait: false,
});
LockedQuery { query: self }
}
fn skip_locked(mut self) -> Self {
self.lock = Some(match self.lock.expect("lock clause must be selected") {
LockClause::ForUpdate { nowait, .. } => LockClause::ForUpdate {
skip_locked: true,
nowait,
},
LockClause::Shared { nowait, .. } => LockClause::Shared {
skip_locked: true,
nowait,
},
});
self
}
fn nowait(mut self) -> Self {
self.lock = Some(match self.lock.expect("lock clause must be selected") {
LockClause::ForUpdate { skip_locked, .. } => LockClause::ForUpdate {
skip_locked,
nowait: true,
},
LockClause::Shared { skip_locked, .. } => LockClause::Shared {
skip_locked,
nowait: true,
},
});
self
}
pub fn with<C>(mut self, ctes: C) -> Self
where
C: IntoCtes,
{
for cte in ctes.into_ctes(false) {
self.push_cte(cte);
}
self
}
pub(crate) fn with_many(mut self, ctes: Vec<Cte>) -> Self {
for cte in ctes {
self.push_cte(cte);
}
self
}
pub fn with_recursive<C>(mut self, ctes: C) -> Self
where
C: IntoCtes,
{
for cte in ctes.into_ctes(true) {
self.push_cte(cte);
}
self
}
fn push_cte(&mut self, mut cte: Cte) {
let cte_param_offset = self
.ctes
.iter()
.map(|existing| existing.query.params.len())
.sum::<usize>();
let cte_param_count = cte.query.params.len();
let data_offset = self.data.len() as u32;
self.data.extend_from_slice(&cte.query.data);
offset_query_data(&mut cte.query, data_offset);
shift_query_params(&mut cte.query, cte_param_offset as u32);
shift_main_query_params(self, cte_param_count as u32);
self.params.splice(
cte_param_offset..cte_param_offset,
cte.query.params.iter().copied(),
);
self.ctes.push(cte);
}
pub fn scalar<T: TypeMeta>(self) -> Scalar<T> {
Scalar {
inner: self,
marker: PhantomData,
}
}
pub fn distinct(mut self) -> Self {
self.distinct = true;
self
}
pub fn not_distinct(mut self) -> Self {
self.distinct = false;
self
}
pub fn typed<M>(self) -> QueryOf<M>
where
M: FromRow,
{
QueryOf::new(self)
}
pub fn typed_locked<M>(self) -> LockedQueryOf<M>
where
M: FromRow,
{
LockedQueryOf::new(self)
}
pub(crate) fn debug_params<W: Write>(&self, writer: &mut W) -> std::fmt::Result {
writer.write_str("; params=[")?;
for (i, param) in self.params.iter().enumerate() {
if i > 0 {
writer.write_str(", ")?;
}
match param {
Param::Null => writer.write_str("null")?,
Param::Bool(value) => match value {
Some(value) => write!(writer, "{}", value)?,
None => writer.write_str("null")?,
},
Param::Float(value) => match value {
Some(value) => write!(writer, "{}", value)?,
None => writer.write_str("null")?,
},
Param::Double(value) => match value {
Some(value) => write!(writer, "{}", value)?,
None => writer.write_str("null")?,
},
Param::Int(value) => match value {
Some(value) => write!(writer, "{}", value)?,
None => writer.write_str("null")?,
},
Param::BigInt(value) => match value {
Some(value) => write!(writer, "{}", value)?,
None => writer.write_str("null")?,
},
Param::Text(span) => match span {
Some(span) => write!(writer, "\"{}\"", self.data.text(*span))?,
None => writer.write_str("null")?,
},
Param::Blob(span) => match span {
Some(span) => write!(writer, "\"{:?}\"", self.data.blob(*span))?,
None => writer.write_str("null")?,
},
Param::UInt(value) => match value {
Some(value) => write!(writer, "{}", value)?,
None => writer.write_str("null")?,
},
Param::UBigInt(value) => match value {
Some(value) => write!(writer, "{}", value)?,
None => writer.write_str("null")?,
},
}
}
writer.write_char(']')?;
Ok(())
}
pub fn to_sql<D: HasDialect>(&mut self) -> String {
let mut buf = String::new();
let mut directives = Vec::new();
let mut params = Vec::new();
let mut emitter = Emitter::new(
&mut buf,
&self.data,
D::DIALECT,
&mut directives,
&mut params,
);
emitter.emit_query(self).unwrap();
for directive in directives {
match directive {
Directive::RewriteGlob { id } => {
let maybe_param = self.params.get_mut(id);
if let Some(Param::Text(Some(text_span))) = maybe_param {
let text = self.data.text(*text_span);
let value = prepare_sqlite_glob(text);
if let Cow::Owned(value) = value {
if let TextSource::Text(span) = text_span.0
&& value.len() == text.len()
{
let v = &mut self.data
[span.start as usize..span.start as usize + span.len as usize];
v.copy_from_slice(value.as_bytes());
} else {
let span = self.data.intern_text(&value);
*text_span = span;
}
}
}
}
}
}
rewrite_params(¶ms, &mut self.params);
buf
}
pub fn into_compiled<D: HasDialect>(mut self) -> Compiled {
Compiled {
sql: self.to_sql::<D>(),
params: self.params,
data: self.data,
}
}
pub fn to_debug_sql<D: HasDialect>(&mut self) -> String {
let mut sql = self.to_sql::<D>();
self.debug_params(&mut sql)
.expect("cannot fail with string writer");
sql
}
}
/// Reorders `params` in place so that the new `params[i]` is the old
/// `params[indexes[i]]`.
///
/// An empty `indexes` means "no reordering needed" and leaves `params`
/// untouched. `Query::to_sql` calls this with the index list collected by the
/// emitter.
///
/// # Panics
/// Panics if `indexes` is non-empty and its length differs from `params.len()`,
/// or if any entry of `indexes` is out of bounds for `params`.
pub fn rewrite_params<T: Clone>(indexes: &[usize], params: &mut [T]) {
    if indexes.is_empty() {
        return;
    }
    assert_eq!(
        indexes.len(),
        params.len(),
        "parameter index list must cover every parameter exactly once"
    );
    // Snapshot first: later slots may read values earlier slots overwrite.
    let original = params.to_vec();
    for (slot, &source) in params.iter_mut().zip(indexes) {
        *slot = original[source].clone();
    }
}
/// Adds `delta` to the id of every `RpnInstr::Param` in `instrs`.
fn shift_instr_params(instrs: &mut [RpnInstr], delta: u32) {
    for instr in instrs.iter_mut() {
        let RpnInstr::Param { id } = instr else {
            continue;
        };
        *id += delta;
    }
}
/// Rebases the data spans carried by `instrs` by `offset` bytes.
///
/// Covers the span of `Table`, `Raw` and `Alias` instructions, plus the column
/// span and optional table-qualifier span of `Column`; other instructions are
/// left unchanged.
fn offset_instr_data(instrs: &mut [RpnInstr], offset: u32) {
    for instr in instrs.iter_mut() {
        match instr {
            RpnInstr::Table { span }
            | RpnInstr::Raw { span, .. }
            | RpnInstr::Alias { span, .. } => {
                *span = span.offset(offset);
            }
            RpnInstr::Column { span, table } => {
                *span = span.offset(offset);
                if let Some(qualifier) = table.as_mut() {
                    *qualifier = qualifier.offset(offset);
                }
            }
            _ => {}
        }
    }
}
/// Shifts parameter ids by `delta` in every clause of `query` and in its
/// compound right-hand side (including that side's CTEs), but not in `query`'s
/// own CTEs.
fn shift_main_query_params(query: &mut Query, delta: u32) {
    let clauses = [
        &mut query.project,
        &mut query.from,
        &mut query.filters,
        &mut query.havings,
        &mut query.group_by,
        &mut query.order_by,
        &mut query.limit,
        &mut query.offset,
    ];
    for instrs in clauses {
        shift_instr_params(instrs, delta);
    }
    if let Some(compound) = query.compound.as_mut() {
        shift_query_params(&mut compound.right, delta);
    }
}
/// Shifts parameter ids by `delta` throughout `query`: its clauses, its
/// compound side and, recursively, all of its CTEs.
fn shift_query_params(query: &mut Query, delta: u32) {
    shift_main_query_params(query, delta);
    query
        .ctes
        .iter_mut()
        .for_each(|cte| shift_query_params(&mut cte.query, delta));
}
/// Rebases instruction data spans by `offset` bytes throughout `query`: every
/// clause, then the compound right-hand side, then each CTE (recursively).
fn offset_query_data(query: &mut Query, offset: u32) {
    let clauses = [
        &mut query.project,
        &mut query.from,
        &mut query.filters,
        &mut query.havings,
        &mut query.group_by,
        &mut query.order_by,
        &mut query.limit,
        &mut query.offset,
    ];
    for instrs in clauses {
        offset_instr_data(instrs, offset);
    }
    if let Some(compound) = query.compound.as_mut() {
        offset_query_data(&mut compound.right, offset);
    }
    query
        .ctes
        .iter_mut()
        .for_each(|cte| offset_query_data(&mut cte.query, offset));
}
/// Output of `Query::into_compiled`: the rendered SQL plus everything needed to
/// bind it.
#[derive(Debug)]
pub struct Compiled {
    /// Rendered SQL text.
    pub sql: String,
    /// Parameter values, after the reordering done by `Query::to_sql`.
    pub params: Vec<Param>,
    /// Byte buffer backing data-resident text parameters.
    pub data: Vec<u8>,
}
/// A [`Compiled`] query tagged with its row type `T`.
///
/// Built by the typed wrappers' `into_compiled`; the fields mirror
/// [`Compiled`] and `marker` is a zero-sized type tag.
#[derive(Debug)]
pub struct TypedCompiled<T> {
    /// Rendered SQL text.
    pub sql: String,
    /// Parameter values, after the reordering done by `Query::to_sql`.
    pub params: Vec<Param>,
    /// Byte buffer backing data-resident text parameters.
    pub data: Vec<u8>,
    /// Row-type tag.
    pub marker: PhantomData<T>,
}
#[cfg(test)]
mod tests {
use crate::{
BigInt, Bool, Boolean, Float, MySql, Nullable, Numeric, Postgres, Qrafting, Sqlite,
aggregate::count,
alias::Alias,
cte::{CteDefinition, with, with_recursive},
expression::{
Column, EqExt, Expression, In, LikeExt, OrderExt, PredicateExt, TimestampExt,
TimestampRelativeExt, abs, coalesce, current_date, current_time, length, lit, lower,
not, now, null_if, round, upper,
},
query::{Order, Table, select, select_all, star},
tests::{User, id, name, table, username},
};
struct QueuedUser;
impl Qrafting for QueuedUser {
type Schema = ();
type QueryPolicy = crate::DefaultQueryPolicy<Self>;
const FIELD_COUNT: usize = 1;
const TABLE: &'static str = "queued_users";
}
struct NamedUser;
impl Qrafting for NamedUser {
type Schema = ();
type QueryPolicy = crate::DefaultQueryPolicy<Self>;
const FIELD_COUNT: usize = 1;
const TABLE: &'static str = "named_users";
}
struct Chain;
impl Qrafting for Chain {
type Schema = ();
type QueryPolicy = crate::DefaultQueryPolicy<Self>;
const FIELD_COUNT: usize = 1;
const TABLE: &'static str = "chain";
}
fn assert_nullable_bool<E>(_: &E)
where
E: Expression<Type = Nullable<Bool>>,
{
}
fn assert_float<E>(_: &E)
where
E: Expression<Type = Float>,
{
}
fn assert_boolean_expr<E>(_: &E)
where
E: Expression,
E::Type: Boolean,
{
}
fn assert_numeric_expr<E>(_: &E)
where
E: Expression,
E::Type: Numeric,
{
}
fn assert_bigint<E>(_: &E)
where
E: Expression<Type = BigInt>,
{
}
#[test]
fn test_simple_select() {
let stmt = select((id, username))
.from(table)
.filter(id.eq(10))
.order_by_random()
.to_debug_sql::<Sqlite>();
assert_eq!(
stmt,
r#"select "users"."id", "users"."username" from "users" where "users"."id" = ? order by random(); params=[10]"#
);
}
#[test]
fn test_select_like() {
let stmt = select((id, username))
.from(table)
.filter(username.like("hello"))
.to_debug_sql::<Postgres>();
assert_eq!(
stmt,
r#"select "users"."id", "users"."username" from "users" where "users"."username"::text ilike $1; params=["hello"]"#
);
}
#[test]
fn test_select_glob() {
let stmt = select_all()
.from(table)
.filter(username.like("%hel*lo%").case_sensitive())
.to_debug_sql::<Sqlite>();
assert_eq!(
stmt,
r#"select * from "users" where "users"."username" glob ?; params=["*hel[*]lo*"]"#
);
}
#[test]
fn test_select_alias() {
let alias = table.alias("u");
let ia = alias.col(id);
let iu = alias.col(username);
let stmt = select((ia, iu))
.from(alias)
.filter(iu.like("hello").case_sensitive())
.to_debug_sql::<Postgres>();
assert_eq!(
stmt,
r#"select "u"."id", "u"."username" from "users" as "u" where "u"."username"::text like $1; params=["hello"]"#
);
}
#[test]
fn test_select_boolean_precedence() {
let stmt = select_all()
.from(table)
.filter(id.eq(1).or(id.eq(2).and(id.eq(3))))
.to_debug_sql::<Sqlite>();
assert_eq!(
stmt,
r#"select * from "users" where "users"."id" = ? or "users"."id" = ? and "users"."id" = ?; params=[1, 2, 3]"#
);
}
#[test]
fn test_select_between_arithmetic_precedence() {
let stmt = select(lit::<BigInt>(1) + lit::<BigInt>(2) + lit::<BigInt>(3))
.from(table)
.filter(id.gt(0))
.to_debug_sql::<Sqlite>();
assert_eq!(
stmt,
r#"select 1 + 2 + 3 from "users" where "users"."id" > ?; params=[0]"#
);
}
#[test]
fn test_select_join_limit_offset_and_order() {
let other = table.alias("other");
let stmt = select((id, username))
.from(table)
.inner_join(other.on(id.eq(other.col(id))))
.order_by((id.desc(), username.asc()))
.limit(5_i64)
.offset(10_i64)
.to_debug_sql::<Sqlite>();
assert_eq!(
stmt,
r#"select "users"."id", "users"."username" from "users" inner join "users" as "other" on "users"."id" = "other"."id" order by "users"."id" desc, "users"."username" asc limit ? offset ?; params=[5, 10]"#
);
}
#[test]
fn test_select_distinct_group_by_and_having() {
let stmt = select((username, count(id).alias("total")))
.from(table)
.distinct()
.group_by(username)
.having(true)
.to_debug_sql::<Sqlite>();
assert_eq!(
stmt,
r#"select distinct "users"."username", count("users"."id") as "total" from "users" group by "users"."username" having ?; params=[true]"#
);
}
#[test]
fn test_select_limit_and_offset_can_be_reset() {
let mut query = select_all()
.from(table)
.limit(10)
.offset(20)
.typed::<User>();
query.reset_limit();
query.reset_offset();
let stmt = query.to_sql::<Sqlite>();
assert_eq!(stmt, r#"select * from "users""#);
}
#[test]
fn test_nullable_boolean_logic_types_and_sql() {
let admin: Column<User, Nullable<Bool>> = Column::new("admin");
let moderator: Column<User, Nullable<Bool>> = Column::new("moderator");
let expr = admin.or(moderator).and(not(admin));
assert_nullable_bool(&admin.eq(moderator));
assert_nullable_bool(&admin.or(moderator));
assert_boolean_expr(&expr);
let stmt = select_all()
.from(table)
.filter(expr)
.to_debug_sql::<Postgres>();
assert_eq!(
stmt,
r#"select * from "users" where ("users"."admin" or "users"."moderator") and not ("users"."admin"); params=[]"#
);
}
#[test]
fn test_nullable_between_and_in_type_and_sql() {
let score: Column<User, Nullable<BigInt>> = Column::new("score");
let between = score.between(Some(10_i64), Some(20_i64));
let in_list = score.in_([Some(1_i64), Some(2_i64)]);
assert_nullable_bool(&between);
assert_nullable_bool(&in_list);
let stmt = select_all()
.from(table)
.filter(between.and(in_list))
.to_debug_sql::<Postgres>();
assert_eq!(
stmt,
r#"select * from "users" where "users"."score" between $1 and $2 and "users"."score" in ($3, $4); params=[10, 20, 1, 2]"#,
);
}
#[test]
fn test_mixed_numeric_math_types_and_sql() {
let score: Column<User, Float> = Column::new("score");
let expr = id + score;
assert_float(&expr);
assert_numeric_expr(&expr);
let stmt = select_all()
.from(table)
.filter(expr.eq(52.5_f32))
.to_debug_sql::<Postgres>();
assert_eq!(
stmt,
r#"select * from "users" where "users"."id" + "users"."score" = $1; params=[52.5]"#,
);
}
#[test]
fn test_not_in_emits_not_in() {
let stmt = select_all()
.from(table)
.filter(id.not_in([1_i64, 2_i64]))
.to_debug_sql::<Postgres>();
assert_eq!(
stmt,
r#"select * from "users" where "users"."id" not in ($1, $2); params=[1, 2]"#,
);
}
#[test]
fn test_select_lock_for_update_postgres() {
let stmt = select_all()
.from(table)
.lock_for_update()
.to_sql::<Postgres>();
assert_eq!(stmt, r#"select * from "users" for update"#);
}
#[test]
fn test_select_shared_lock_postgres() {
let stmt = select_all().from(table).shared_lock().to_sql::<Postgres>();
assert_eq!(stmt, r#"select * from "users" for share"#);
}
#[test]
fn test_select_shared_lock_mariadb() {
let stmt = select_all().from(table).shared_lock().to_sql::<MySql>();
assert_eq!(stmt, "select * from `users` lock in share mode");
}
#[test]
fn test_select_lock_for_update_skip_locked_postgres() {
let stmt = select_all()
.from(table)
.lock_for_update()
.skip_locked()
.to_sql::<Postgres>();
assert_eq!(stmt, r#"select * from "users" for update skip locked"#);
}
#[test]
fn test_select_lock_for_update_nowait_postgres() {
let stmt = select_all()
.from(table)
.lock_for_update()
.nowait()
.to_sql::<Postgres>();
assert_eq!(stmt, r#"select * from "users" for update nowait"#);
}
#[test]
fn test_select_lock_for_update_skip_locked_nowait_postgres() {
let stmt = select_all()
.from(table)
.lock_for_update()
.skip_locked()
.nowait()
.to_sql::<Postgres>();
assert_eq!(
stmt,
r#"select * from "users" for update skip locked nowait"#
);
}
#[test]
fn test_select_shared_lock_nowait_postgres() {
let stmt = select_all()
.from(table)
.shared_lock()
.nowait()
.to_sql::<Postgres>();
assert_eq!(stmt, r#"select * from "users" for share nowait"#);
}
#[test]
fn test_select_shared_lock_skip_locked_postgres() {
let stmt = select_all()
.from(table)
.shared_lock()
.skip_locked()
.to_sql::<Postgres>();
assert_eq!(stmt, r#"select * from "users" for share skip locked"#);
}
#[test]
fn test_select_shared_lock_skip_locked_nowait_postgres() {
let stmt = select_all()
.from(table)
.shared_lock()
.skip_locked()
.nowait()
.to_sql::<Postgres>();
assert_eq!(
stmt,
r#"select * from "users" for share skip locked nowait"#
);
}
#[test]
fn test_select_lock_for_update_skip_locked_mariadb() {
let stmt = select_all()
.from(table)
.lock_for_update()
.skip_locked()
.to_sql::<MySql>();
assert_eq!(stmt, "select * from `users` for update skip locked");
}
#[test]
fn test_select_lock_for_update_nowait_mariadb() {
let stmt = select_all()
.from(table)
.lock_for_update()
.nowait()
.to_sql::<MySql>();
assert_eq!(stmt, "select * from `users` for update nowait");
}
#[test]
fn test_select_lock_for_update_skip_locked_nowait_mariadb() {
let stmt = select_all()
.from(table)
.lock_for_update()
.skip_locked()
.nowait()
.to_sql::<MySql>();
assert_eq!(stmt, "select * from `users` for update skip locked nowait");
}
#[test]
#[should_panic(expected = "locking clauses are not supported for sqlite")]
fn test_select_lock_for_update_sqlite_panics() {
let _ = select_all()
.from(table)
.lock_for_update()
.to_sql::<Sqlite>();
}
#[test]
#[should_panic(expected = "locking clauses are not supported for sqlite")]
fn test_select_shared_lock_sqlite_panics() {
let _ = select_all().from(table).shared_lock().to_sql::<Sqlite>();
}
#[test]
#[should_panic(expected = "lock modifiers are not supported with shared_lock() for mariadb")]
fn test_select_shared_lock_skip_locked_mariadb_panics() {
let _ = select_all()
.from(table)
.shared_lock()
.skip_locked()
.to_sql::<MySql>();
}
#[test]
#[should_panic(expected = "lock modifiers are not supported with shared_lock() for mariadb")]
fn test_select_shared_lock_nowait_mariadb_panics() {
let _ = select_all()
.from(table)
.shared_lock()
.nowait()
.to_sql::<MySql>();
}
#[test]
fn test_select_with_single_cte() {
let queued = CteDefinition::<QueuedUser>::new(
"queued_users",
&[],
select(id).from(table).filter(id.eq(10)),
);
let stmt = with(queued)
.select(star())
.from(Table::<QueuedUser>::new("queued_users"))
.to_debug_sql::<Postgres>();
assert_eq!(
stmt,
r#"with "queued_users" as (select "users"."id" from "users" where "users"."id" = $1) select * from "queued_users"; params=[10]"#
);
}
#[test]
fn test_select_with_multiple_ctes_and_columns() {
let queued = CteDefinition::<QueuedUser>::new(
"queued_users",
&["id"],
select(id).from(table).filter(id.eq(10)),
);
let named =
CteDefinition::<NamedUser>::new("named_users", &[], select(username).from(table));
let stmt = with((queued, named))
.select(star())
.from((
Table::<QueuedUser>::new("queued_users"),
Table::<NamedUser>::new("named_users"),
))
.to_debug_sql::<Postgres>();
assert_eq!(
stmt,
r#"with "queued_users"("id") as (select "users"."id" from "users" where "users"."id" = $1), "named_users" as (select "users"."username" from "users") select * from ("queued_users" cross join "named_users"); params=[10]"#
);
}
#[test]
fn test_select_with_recursive_cte() {
let chain = CteDefinition::<Chain>::new("chain", &["id"], select(id).from(table));
let stmt = with_recursive(chain)
.select(star())
.from(Table::<Chain>::new("chain"))
.to_sql::<Postgres>();
assert_eq!(
stmt,
r#"with recursive "chain"("id") as (select "users"."id" from "users") select * from "chain""#
);
}
/// `union_all` wraps each arm in parentheses and numbers the placeholders across both
/// arms, left to right.
#[test]
fn test_select_union_all() {
    let left = select(id).from(table).filter(id.eq(1));
    let right = select(id).from(table).filter(id.eq(2));
    let sql = left.union_all(right).to_debug_sql::<Postgres>();
    let expected = r#"(select "users"."id" from "users" where "users"."id" = $1) union all (select "users"."id" from "users" where "users"."id" = $2); params=[1, 2]"#;
    assert_eq!(sql, expected);
}
/// The string, numeric and temporal function helpers render the same way in every
/// dialect. Only identifier quoting and placeholder syntax differ.
#[test]
fn test_cross_dialect_function_helpers_emit_expected_sql() {
    // Type-level checks on what the helpers return (NOTE(review): presumably these
    // assert_* helpers only constrain the argument type).
    assert_bigint(&length(username));
    assert_numeric_expr(&abs(id));
    assert_numeric_expr(&round(id));

    // Each closure rebuilds the same query so it can be rendered once per dialect.
    let scalar_query = || {
        select((
            lower(username).alias("lower_name"),
            upper(username).alias("upper_name"),
            length(username).alias("name_len"),
            abs(id).alias("abs_id"),
            round(id).alias("rounded_id"),
            coalesce(name, "fallback").alias("display_name"),
            null_if(username, "root").alias("maybe_name"),
        ))
        .from(table)
    };
    let temporal_query = || {
        select((
            current_date().alias("today"),
            current_time().alias("clock"),
            now().alias("current_time"),
        ))
        .from(table)
    };

    // Postgres and SQLite produce byte-identical SQL for the parameterless temporal query.
    let double_quoted_temporal = r#"select current_date as "today", current_time as "clock", current_timestamp as "current_time" from "users"; params=[]"#;

    let postgres_expected = r#"select lower("users"."username") as "lower_name", upper("users"."username") as "upper_name", length("users"."username") as "name_len", abs("users"."id") as "abs_id", round("users"."id") as "rounded_id", coalesce("users"."name", $1) as "display_name", nullif("users"."username", $2) as "maybe_name" from "users"; params=["fallback", "root"]"#;
    assert_eq!(scalar_query().to_debug_sql::<Postgres>(), postgres_expected);
    assert_eq!(
        temporal_query().to_debug_sql::<Postgres>(),
        double_quoted_temporal
    );

    let sqlite_expected = r#"select lower("users"."username") as "lower_name", upper("users"."username") as "upper_name", length("users"."username") as "name_len", abs("users"."id") as "abs_id", round("users"."id") as "rounded_id", coalesce("users"."name", ?) as "display_name", nullif("users"."username", ?) as "maybe_name" from "users"; params=["fallback", "root"]"#;
    assert_eq!(scalar_query().to_debug_sql::<Sqlite>(), sqlite_expected);
    assert_eq!(
        temporal_query().to_debug_sql::<Sqlite>(),
        double_quoted_temporal
    );

    let mariadb_expected = r#"select lower(`users`.`username`) as `lower_name`, upper(`users`.`username`) as `upper_name`, length(`users`.`username`) as `name_len`, abs(`users`.`id`) as `abs_id`, round(`users`.`id`) as `rounded_id`, coalesce(`users`.`name`, ?) as `display_name`, nullif(`users`.`username`, ?) as `maybe_name` from `users`; params=["fallback", "root"]"#;
    assert_eq!(scalar_query().to_debug_sql::<MySql>(), mariadb_expected);
    let mariadb_temporal_expected = r#"select current_date as `today`, current_time as `clock`, current_timestamp as `current_time` from `users`; params=[]"#;
    assert_eq!(
        temporal_query().to_debug_sql::<MySql>(),
        mariadb_temporal_expected
    );
}
/// Date/time extraction helpers (`date`, `time`, `year`, `month`, `day`) render
/// dialect-specific SQL in both the projection and the WHERE clause.
#[test]
fn test_cross_dialect_temporal_extract_helpers_emit_expected_sql() {
    let created_at = Column::<User, crate::Timestamp>::new("created_at");
    let clock = quex::Time {
        hour: 12,
        minute: 34,
        second: 56,
        microsecond: 0,
    };
    // Builds the same query on every call so it can be rendered once per dialect.
    let extract_query = || {
        select((
            created_at.date().alias("date_only"),
            created_at.time().alias("time_only"),
            created_at.year().alias("year_only"),
            created_at.month().alias("month_only"),
            created_at.day().alias("day_only"),
        ))
        .from(table)
        .filter(created_at.month().eq(4_i64))
        .filter(created_at.time().gt(clock))
    };

    let postgres_expected = r#"select cast("users"."created_at" as date) as "date_only", cast("users"."created_at" as time) as "time_only", cast(extract(year from "users"."created_at") as bigint) as "year_only", cast(extract(month from "users"."created_at") as bigint) as "month_only", cast(extract(day from "users"."created_at") as bigint) as "day_only" from "users" where cast(extract(month from "users"."created_at") as bigint) = $1 and cast("users"."created_at" as time) > $2; params=[4, "12:34:56"]"#;
    assert_eq!(extract_query().to_debug_sql::<Postgres>(), postgres_expected);

    let sqlite_expected = r#"select date("users"."created_at") as "date_only", time("users"."created_at") as "time_only", cast(strftime('%Y', "users"."created_at") as integer) as "year_only", cast(strftime('%m', "users"."created_at") as integer) as "month_only", cast(strftime('%d', "users"."created_at") as integer) as "day_only" from "users" where cast(strftime('%m', "users"."created_at") as integer) = ? and time("users"."created_at") > ?; params=[4, "12:34:56"]"#;
    assert_eq!(extract_query().to_debug_sql::<Sqlite>(), sqlite_expected);

    let mariadb_expected = r#"select date(`users`.`created_at`) as `date_only`, time(`users`.`created_at`) as `time_only`, year(`users`.`created_at`) as `year_only`, month(`users`.`created_at`) as `month_only`, day(`users`.`created_at`) as `day_only` from `users` where month(`users`.`created_at`) = ? and time(`users`.`created_at`) > ?; params=[4, "12:34:56"]"#;
    assert_eq!(extract_query().to_debug_sql::<MySql>(), mariadb_expected);
}
/// Relative-time predicates (`past`, `today`, `after_today`) compare against
/// `current_timestamp` / `current_date` without binding any parameters.
#[test]
fn test_cross_dialect_temporal_relative_helpers_emit_expected_sql() {
    let created_at = Column::<User, crate::Timestamp>::new("created_at");
    // Builds the same filtered query on every call so it can be rendered once per dialect.
    let relative_query = || {
        select(id)
            .from(table)
            .filter(created_at.past())
            .filter(created_at.today().or(created_at.after_today()))
    };

    let postgres_expected = r#"select "users"."id" from "users" where "users"."created_at" < current_timestamp and (cast("users"."created_at" as date) = current_date or cast("users"."created_at" as date) > current_date); params=[]"#;
    assert_eq!(relative_query().to_debug_sql::<Postgres>(), postgres_expected);

    let sqlite_expected = r#"select "users"."id" from "users" where "users"."created_at" < current_timestamp and (date("users"."created_at") = current_date or date("users"."created_at") > current_date); params=[]"#;
    assert_eq!(relative_query().to_debug_sql::<Sqlite>(), sqlite_expected);

    let mariadb_expected = r#"select `users`.`id` from `users` where `users`.`created_at` < current_timestamp and (date(`users`.`created_at`) = current_date or date(`users`.`created_at`) > current_date); params=[]"#;
    assert_eq!(relative_query().to_debug_sql::<MySql>(), mariadb_expected);
}
}