use crate::{
db::{StatementContext, Transaction},
schema::{
datum::{BorrowedDatum, BorrowedDatumList, Datum},
entity::{Entity, EntityID, EntityPart, EntityPartList, EntityRef},
index::Index,
relation::{LocalSide, RelationData},
Borrowed, Stored,
},
DBResult, Error,
};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
pub(crate) mod base_queries;
pub(crate) mod components;
pub(crate) mod containers;
use containers::*;
/// A fragment of SQL text held by a query, either owned or borrowed for `'l`.
#[derive(Debug, Clone)]
pub(crate) enum QueryPartData<'l> {
    /// Fragment stored as an owned `String`.
    Owned(String),
    /// Fragment stored as a borrowed string slice.
    Borrowed(&'l str),
}

impl std::fmt::Display for QueryPartData<'_> {
    /// Writes the fragment text verbatim, regardless of how it is stored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text: &str = match self {
            QueryPartData::Owned(owned) => owned.as_str(),
            QueryPartData::Borrowed(borrowed) => *borrowed,
        };
        f.write_str(text)
    }
}

impl From<String> for QueryPartData<'_> {
    /// Wraps an owned string without copying it.
    fn from(value: String) -> Self {
        QueryPartData::Owned(value)
    }
}

impl<'l> From<&'l str> for QueryPartData<'l> {
    /// Wraps a string slice without allocating.
    fn from(value: &'l str) -> Self {
        QueryPartData::Borrowed(value)
    }
}
/// Identifies which section of an SQL statement a fragment belongs to.
///
/// `Query::assemble` emits the sections in the order the variants are listed here.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub(crate) enum QueryPart {
/// Leading statement text; `assemble` uses only the first fragment attached here.
Root,
/// Column expressions; fragments are joined with `,`.
Columns,
/// Source tables; fragments are joined with `,` and prefixed with `FROM`.
From,
/// Assignments; fragments are joined with `,` and prefixed with `SET`.
Set,
/// Join clauses; each fragment is prefixed with `INNER JOIN`.
Join,
/// Conditions; fragments are joined with ` AND ` and prefixed with `WHERE`.
Where,
/// Ordering text; fragments are joined with a single space, no keyword is added.
Order,
/// Limit text; fragments are joined with a single space, no keyword is added.
Limit,
/// Text emitted after every other section; fragments are joined with a single space.
Trailing,
}
/// Extension trait that concatenates the items of an iterator into one `String`.
trait Joinable {
    /// Consumes `self` and returns its items rendered in order, with `sep` between neighbours.
    fn join(self, sep: &str) -> String;
}

impl<T: std::fmt::Display, I: Iterator<Item = T>> Joinable for I {
    fn join(mut self, sep: &str) -> String {
        use std::fmt::Write;

        let mut joined = String::new();
        // Emit the head without a separator, then prefix every later item with one.
        if let Some(head) = self.next() {
            write!(joined, "{}", head).unwrap();
            for item in self {
                joined.push_str(sep);
                write!(joined, "{}", item).unwrap();
            }
        }
        joined
    }
}
/// An SQL statement under construction, kept as text fragments grouped by `QueryPart`.
#[derive(Debug)]
pub struct Query<'l> {
// Fragments attached so far, in attachment order, keyed by the section they belong to.
parts: HashMap<QueryPart, Vec<QueryPartData<'l>>>,
}
impl<'l> Query<'l> {
    /// Creates a query with no fragments attached.
    pub(crate) fn new() -> Self {
        Self {
            parts: Default::default(),
        }
    }

    /// Appends `val` to the fragments stored under `qp` and returns the query (builder style).
    pub(crate) fn attach<T: Into<QueryPartData<'l>>>(mut self, qp: QueryPart, val: T) -> Self {
        self.attach_mut(qp, val);
        self
    }

    /// Discards every fragment stored under `qp`, then stores `val` as the only one.
    pub(crate) fn replace<T: Into<QueryPartData<'l>>>(mut self, qp: QueryPart, val: T) -> Self {
        self.parts.remove(&qp);
        self.attach(qp, val)
    }

    /// Appends `val` to the fragments stored under `qp`, in place.
    pub(crate) fn attach_mut<T: Into<QueryPartData<'l>>>(&mut self, qp: QueryPart, val: T) {
        self.parts.entry(qp).or_default().push(val.into());
    }

    /// Removes the fragments of `part` and renders them as `keyword` followed by the
    /// fragments joined with `sep`; yields an empty string when nothing was attached.
    fn take_joined(&mut self, part: QueryPart, keyword: &str, sep: &str) -> String {
        match self.parts.remove(&part) {
            None => String::new(),
            Some(fragments) => format!("{}{}", keyword, fragments.into_iter().join(sep)),
        }
    }

    /// Renders the query as SQL text.
    ///
    /// Sections are emitted in the order root, columns, `FROM`, `SET`, joins, `WHERE`,
    /// order, limit, trailing, each separated by a single space. A section with no
    /// fragments contributes an empty string, so its separating space is still emitted.
    ///
    /// # Panics
    /// Panics if no `QueryPart::Root` fragment was attached.
    pub(crate) fn assemble(mut self) -> String {
        // Only the first root fragment is used; any further ones are dropped.
        let root = self
            .parts
            .remove(&QueryPart::Root)
            .and_then(|fragments| fragments.into_iter().next())
            .expect("query assembled without a Root part");

        let columns_ = self.take_joined(QueryPart::Columns, "", ",");
        let from_ = self.take_joined(QueryPart::From, "FROM ", ",");
        let set_ = self.take_joined(QueryPart::Set, "SET ", ",");
        // Every join fragment carries its own keyword, so prefix each one before joining.
        let join_ = self
            .parts
            .remove(&QueryPart::Join)
            .map(|fragments| {
                fragments
                    .into_iter()
                    .map(|j| format!("INNER JOIN {}", j))
                    .join(" ")
            })
            .unwrap_or_default();
        let where_ = self.take_joined(QueryPart::Where, "WHERE ", " AND ");
        let order_ = self.take_joined(QueryPart::Order, "", " ");
        let limit_ = self.take_joined(QueryPart::Limit, "", " ");
        let trailing_ = self.take_joined(QueryPart::Trailing, "", " ");

        format!("{root} {columns_} {from_} {set_} {join_} {where_} {order_} {limit_} {trailing_}")
    }
}
/// The names needed to build SQL for one side of a relation, as gathered by
/// `RelationNames::collect`.
pub(crate) struct RelationNames {
// `local_name` of the relation data returned by the interface.
local_name: &'static str,
// Entity name of the interface's `RemoteEntity`.
remote_name: &'static str,
// `part_name` of the relation data returned by the interface.
part_name: &'static str,
// Distinguishing name reported by the interface; last component of the relation table name.
dist_name: &'static str,
// `local_name` when the local side is the domain, otherwise `remote_name`.
domain_name: &'static str,
// `remote_name` when the local side is the domain, otherwise `local_name`.
range_name: &'static str,
// Relation-table column for the local side: "domain" or "range".
local_field: &'static str,
// Relation-table column for the remote side: the opposite of `local_field`.
remote_field: &'static str,
}
impl RelationNames {
fn collect<AI: RelationInterface>(iface: &AI) -> DBResult<RelationNames> {
let rdata = iface.get_data()?;
let local_name = rdata.local_name;
let remote_name = <AI::RemoteEntity>::entity_name();
let part_name = rdata.part_name;
let dist_name = iface.get_distinguishing_name()?;
let (domain_name, range_name) = match AI::SIDE {
LocalSide::Domain => (local_name, remote_name),
LocalSide::Range => (remote_name, local_name),
};
let (local_field, remote_field) = match AI::SIDE {
LocalSide::Domain => ("domain", "range"),
LocalSide::Range => ("range", "domain"),
};
Ok(Self {
local_name,
remote_name,
part_name,
dist_name,
domain_name,
range_name,
local_field,
remote_field,
})
}
fn relation_name(&self) -> String {
format!(
"{domain_name}_{range_name}_relation_{dist_name}",
domain_name = self.domain_name,
range_name = self.range_name,
dist_name = self.dist_name
)
}
}
/// Hashes `val` with a freshly created `DefaultHasher` and returns the 64-bit digest.
fn hash_of<T: Hash>(val: T) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(&val, &mut state);
    Hasher::finish(&state)
}
/// Operations available on one side of a relation between two entity types.
pub trait RelationInterface {
    /// Entity type on the far side of the relation.
    type RemoteEntity: Entity;
    /// `'static` counterpart of this interface over the same remote entity.
    type StaticVersion: RelationInterface<RemoteEntity = Self::RemoteEntity> + 'static;

    #[doc(hidden)]
    fn get_data(&self) -> DBResult<&RelationData>;
    #[doc(hidden)]
    fn get_distinguishing_name(&self) -> DBResult<&'static str>;

    /// Which side of the relation (domain or range) the local entity is on.
    const SIDE: LocalSide;

    /// Returns a query built from a `TableComponent` over the remote entity type.
    ///
    /// `self` is not consulted, so the query is not restricted to this relation instance.
    fn query_all(&self) -> impl Queryable<EntityOutput = Self::RemoteEntity> {
        components::TableComponent::<Self::RemoteEntity>::new()
    }

    /// Connects the local entity to the remote entity `remote_id`.
    ///
    /// # Errors
    /// Propagates failures from reading the relation data, collecting the relation names,
    /// and from `base_queries::do_connect`.
    fn connect_to(
        &self,
        txn: &mut Transaction,
        remote_id: <Self::RemoteEntity as Entity>::ID,
    ) -> DBResult<()>
    where
        Self: Sized,
    {
        let data = self.get_data()?;
        let names = RelationNames::collect::<Self>(self)?;
        base_queries::do_connect::<Self::RemoteEntity>(txn, data, names, remote_id)
    }

    /// Deletes the relation row linking the local entity to `remote_id`.
    ///
    /// # Errors
    /// Propagates failures from reading the relation data, collecting the relation names,
    /// binding the two IDs, and running the statement.
    fn disconnect_from(
        &self,
        txn: &mut Transaction,
        remote_id: <Self::RemoteEntity as Entity>::ID,
    ) -> DBResult<()>
    where
        Self: Sized,
    {
        let data = self.get_data()?;
        let names = RelationNames::collect::<Self>(self)?;
        // Key handed to `with_prepared` to identify this statement.
        let statement_key = hash_of((
            "disconnect",
            names.local_name,
            names.remote_name,
            names.part_name,
        ));
        txn.lease().with_prepared(
            statement_key,
            || {
                Ok(format!(
                    "delete from `{relation_name}` where `{local_field}` = ? and `{remote_field}` = ?",
                    relation_name = names.relation_name(),
                    local_field = names.local_field,
                    remote_field = names.remote_field
                ))
            },
            |ctx| {
                ctx.bind(1, data.local_id)?;
                ctx.bind(2, remote_id.into_raw())?;
                ctx.run()?;
                Ok(())
            },
        )
    }
}
/// Insertion of entities of type `E` through some handle.
///
/// The blanket implementation for relation interfaces stores the value and then connects
/// the new entity to the relation.
pub trait Insertable<E: Entity> {
/// Inserts an owned `value` and returns the ID of the new entity.
fn insert(&self, txn: &mut Transaction, value: E) -> DBResult<E::ID>;
/// Inserts from the entity's reference form (`E::ERef`) and returns the ID of the new entity.
fn insert_ref(&self, txn: &mut Transaction, value: E::ERef<'_>) -> DBResult<E::ID>;
/// Inserts an owned `value` and returns it in stored form.
fn insert_and_return(&self, txn: &mut Transaction, value: E) -> DBResult<Stored<E>>;
}
/// Every relation interface can insert remote entities: the value is stored first and the
/// new entity is then connected to the relation.
impl<AI: RelationInterface> Insertable<AI::RemoteEntity> for AI {
    /// Stores `value`, connects it to this relation and returns its ID.
    fn insert(
        &self,
        txn: &mut Transaction,
        value: AI::RemoteEntity,
    ) -> DBResult<<AI::RemoteEntity as Entity>::ID>
    where
        Self: Sized,
    {
        // Resolve the relation before writing, so a broken relation fails before any insert.
        let data = self.get_data()?;
        let names = RelationNames::collect::<Self>(self)?;
        let new_id = base_queries::insert(txn, &value)?;
        base_queries::do_connect::<AI::RemoteEntity>(txn, data, names, new_id)?;
        Ok(new_id)
    }

    /// Stores the entity described by `value`, connects it to this relation and returns
    /// its ID.
    fn insert_ref(
        &self,
        txn: &mut Transaction,
        value: <AI::RemoteEntity as Entity>::ERef<'_>,
    ) -> DBResult<<AI::RemoteEntity as Entity>::ID> {
        let data = self.get_data()?;
        let names = RelationNames::collect::<Self>(self)?;
        let new_id = base_queries::insert_ref::<AI::RemoteEntity>(txn, value)?;
        base_queries::do_connect::<AI::RemoteEntity>(txn, data, names, new_id)?;
        Ok(new_id)
    }

    /// Stores `value`, connects it to this relation and returns the stored entity.
    fn insert_and_return(
        &self,
        txn: &mut Transaction,
        value: AI::RemoteEntity,
    ) -> DBResult<Stored<AI::RemoteEntity>>
    where
        Self: Sized,
    {
        let data = self.get_data()?;
        let names = RelationNames::collect::<Self>(self)?;
        let stored = base_queries::insert_and_return(txn, value)?;
        base_queries::do_connect::<AI::RemoteEntity>(txn, data, names, stored.id())?;
        Ok(stored)
    }
}
/// A composable query over stored entities.
///
/// Implementors supply `build` (the SQL text) and `bind` (the parameter values); the
/// provided methods either execute the query or wrap it in a further component.
pub trait Queryable: Clone {
    /// Entity type produced by the query.
    type EntityOutput: Entity;
    /// Container that `get` and `remove` collect their results into.
    type OutputContainer: OutputContainer<Self::EntityOutput>;
    /// `'static` counterpart of this query type. Paired with a per-operation tag type, its
    /// `TypeId` is the key handed to the prepared-statement helpers.
    type StaticVersion: Queryable<EntityOutput = Self::EntityOutput> + 'static;

    #[doc(hidden)]
    fn build(&self) -> DBResult<Query<'_>>;
    #[doc(hidden)]
    fn bind(&self, stmt: &mut StatementContext, index: &mut i32) -> DBResult<()>;

    /// Counts the distinct entity IDs matched by the query.
    ///
    /// # Errors
    /// Fails if building, binding or running the statement fails, or with
    /// `Error::InternalError` if the COUNT statement yields no row.
    fn count(self, txn: &mut Transaction) -> DBResult<usize>
    where
        Self: Sized,
    {
        struct CountTag;
        txn.lease().with_prepared(
            std::any::TypeId::of::<(Self::StaticVersion, CountTag)>(),
            || {
                Ok(self
                    .build()?
                    .replace(
                        QueryPart::Columns,
                        format!(
                            "COUNT(DISTINCT `{}`.`id`)",
                            Self::EntityOutput::entity_name()
                        ),
                    )
                    .assemble())
            },
            |mut ctx| {
                let mut index = 1;
                self.bind(&mut ctx, &mut index)?;
                Ok(ctx
                    .run()?
                    .ok_or(Error::InternalError("no resulting rows from COUNT query"))?
                    .read::<i64>(0)? as usize)
            },
        )
    }

    /// Runs the query and collects the results into `Self::OutputContainer`.
    fn get(self, txn: &mut Transaction) -> DBResult<Self::OutputContainer>
    where
        Self: Sized,
    {
        struct GetTag;
        txn.lease().with_prepared(
            std::any::TypeId::of::<(Self::StaticVersion, GetTag)>(),
            || Ok(self.build()?.assemble()),
            |mut ctx| {
                let mut index = 1;
                self.bind(&mut ctx, &mut index)?;
                <Self::OutputContainer>::assemble_from(ctx)
            },
        )
    }

    /// Runs the query and returns an iterator that yields one stored entity per row.
    ///
    /// # Errors
    /// The outer `Result` reports failures while preparing or binding the statement. Each
    /// item reports a failure to fetch or decode that particular row.
    fn iter(
        self,
        txn: &mut Transaction,
    ) -> DBResult<impl Iterator<Item = DBResult<Stored<Self::EntityOutput>>>>
    where
        Self: Sized,
    {
        struct IterTag;
        Ok(txn
            .lease()
            .iter_with_prepared(
                self,
                std::any::TypeId::of::<(Self::StaticVersion, IterTag)>(),
                |q| Ok(q.build()?.assemble()),
                |q, ctx| {
                    let mut index = 1;
                    q.bind(ctx, &mut index)
                },
            )?
            .map(|row| {
                let mut row = row?;
                // Decoding failures are reported through the item instead of panicking.
                let id = row.read::<i64>(0)?;
                let datum_list =
                    <<Self::EntityOutput as Entity>::Parts>::build_datum_list(&mut row)?;
                Ok(Stored::new(
                    <Self::EntityOutput as Entity>::ID::from_raw(id),
                    <Self::EntityOutput as Entity>::build(datum_list),
                ))
            }))
    }

    /// Runs the query and calls `f` with a borrowed view of each resulting entity.
    ///
    /// Once `f` returns `ControlFlow::Break`, it is not called again; the remaining rows
    /// are still pulled from the statement but are skipped without being inspected.
    ///
    /// # Errors
    /// Fails if preparing or binding the statement fails, or if a row visited before the
    /// break cannot be fetched or decoded.
    fn iter_refs<'a>(
        self,
        txn: &'a mut Transaction,
        mut f: impl for<'b> FnMut(
            Borrowed<'b, <Self::EntityOutput as Entity>::ERef<'b>>,
        ) -> std::ops::ControlFlow<()>,
    ) -> DBResult<()>
    where
        Self: Sized,
    {
        struct IterRefTag;
        let mut early_break = false;
        txn.lease()
            .iter_with_prepared(
                self,
                std::any::TypeId::of::<(Self::StaticVersion, IterRefTag)>(),
                |q| Ok(q.build()?.assemble()),
                |q, ctx| {
                    let mut index = 1;
                    q.bind(ctx, &mut index)
                },
            )?
            .try_for_each::<_, DBResult<()>>(|row| {
                if early_break {
                    return Ok(());
                }
                let mut row = row?;
                let id = row.read::<i64>(0)?;
                let datum_list =
                    <<Self::EntityOutput as Entity>::Parts>::build_datum_ref_list(&mut row)?;
                let r =
                    <<Self::EntityOutput as Entity>::ERef<'_> as EntityRef<'_>>::from_borrowed_list(
                        datum_list,
                    );
                match f(Borrowed::new(
                    <Self::EntityOutput as Entity>::ID::from_raw(id),
                    r,
                )) {
                    std::ops::ControlFlow::Continue(_) => Ok(()),
                    std::ops::ControlFlow::Break(_) => {
                        early_break = true;
                        Ok(())
                    },
                }
            })
    }

    /// Runs the query selecting only the `id` column and collects the IDs into the output
    /// container's `IDContainer`.
    fn get_ids(
        self,
        txn: &mut Transaction,
    ) -> DBResult<<Self::OutputContainer as OutputContainer<Self::EntityOutput>>::IDContainer>
    where
        Self: Sized,
    {
        struct GetIDTag;
        txn.lease().with_prepared(
            std::any::TypeId::of::<(Self::StaticVersion, GetIDTag)>(),
            || {
                Ok(self
                    .build()?
                    .replace(
                        QueryPart::Columns,
                        format!("`{}`.`id`", Self::EntityOutput::entity_name()),
                    )
                    .assemble())
            },
            |mut ctx| {
                let mut index = 1;
                self.bind(&mut ctx, &mut index)?;
                <<Self::OutputContainer as OutputContainer<
                    Self::EntityOutput,
                >>::IDContainer>::assemble_from(ctx)
            },
        )
    }

    /// Deletes every entity matched by the query.
    ///
    /// The query is embedded as an ID subquery. It is matched with `IN` so that all
    /// matching rows are deleted; a scalar `=` comparison would only consider the first
    /// row the subquery returns.
    fn delete(self, txn: &mut Transaction) -> DBResult<()>
    where
        Self: Sized,
    {
        struct DeleteTag;
        txn.lease().with_prepared(
            std::any::TypeId::of::<(Self::StaticVersion, DeleteTag)>(),
            || {
                Ok(format!(
                    "DELETE FROM `{}` WHERE `id` IN ({})",
                    Self::EntityOutput::entity_name(),
                    self.build()?
                        .replace(
                            QueryPart::Columns,
                            format!("`{}`.`id`", Self::EntityOutput::entity_name())
                        )
                        .assemble()
                ))
            },
            |mut ctx| {
                let mut index = 1;
                self.bind(&mut ctx, &mut index)?;
                ctx.run()?;
                Ok(())
            },
        )
    }

    /// Deletes every entity matched by the query and returns the deleted entities.
    ///
    /// Uses `DELETE ... RETURNING *` with the query embedded as an `IN` subquery, so all
    /// matching rows are removed and returned.
    fn remove(self, txn: &mut Transaction) -> DBResult<Self::OutputContainer>
    where
        Self: Sized,
    {
        // Own tag type: the statement text differs from the one `delete` prepares.
        struct RemoveTag;
        txn.lease().with_prepared(
            std::any::TypeId::of::<(Self::StaticVersion, RemoveTag)>(),
            || {
                Ok(format!(
                    "DELETE FROM `{entity}` WHERE `id` IN ({subquery}) RETURNING *",
                    entity = Self::EntityOutput::entity_name(),
                    subquery = self
                        .build()?
                        .replace(
                            QueryPart::Columns,
                            format!("`{}`.`id`", Self::EntityOutput::entity_name())
                        )
                        .assemble()
                ))
            },
            |mut ctx| {
                let mut index = 1;
                self.bind(&mut ctx, &mut index)?;
                <Self::OutputContainer>::assemble_from(ctx)
            },
        )
    }

    /// Narrows the query to at most one entity by supplying `values` for the entity's
    /// `Keys` part list (wraps the query in an `IndexComponent`).
    fn keyed<'l>(
        self,
        values: impl BorrowedDatumList<
            'l,
            <<Self::EntityOutput as Entity>::Keys as EntityPartList>::DatumList,
        >,
    ) -> impl Queryable<
        EntityOutput = Self::EntityOutput,
        OutputContainer = Option<Stored<Self::EntityOutput>>,
    >
    where
        Self: Sized,
    {
        components::IndexComponent::<_, _, <Self::EntityOutput as Entity>::Keys, _>::new(
            self, values,
        )
    }

    /// Narrows the query to at most one entity by supplying `values` for the parts covered
    /// by a unique index. `_index` only selects the part list `EPL`; its value is unused.
    fn indexed<'l, EPL: EntityPartList<Entity = Self::EntityOutput>>(
        self,
        _index: &Index<true, Self::EntityOutput, EPL>,
        values: impl BorrowedDatumList<'l, EPL::DatumList>,
    ) -> impl Queryable<
        EntityOutput = Self::EntityOutput,
        OutputContainer = Option<Stored<Self::EntityOutput>>,
    >
    where
        Self: Sized,
    {
        components::IndexComponent::<_, _, EPL, _>::new(self, values)
    }

    /// Restricts the query by pairing `part` with `value` (wraps the query in a
    /// `WithComponent`).
    fn with<'l, EP: EntityPart<Entity = Self::EntityOutput>>(
        self,
        part: EP,
        value: impl BorrowedDatum<'l, EP::Datum>,
    ) -> impl Queryable<EntityOutput = Self::EntityOutput, OutputContainer = Self::OutputContainer>
    where
        Self: Sized,
    {
        components::WithComponent::new(self, part, value)
    }

    /// Restricts the query to the entity with the given `id`, yielding at most one result.
    fn with_id(
        self,
        id: <Self::EntityOutput as Entity>::ID,
    ) -> impl Queryable<
        EntityOutput = Self::EntityOutput,
        OutputContainer = Option<Stored<Self::EntityOutput>>,
    >
    where
        Self: Sized,
    {
        self.with(<Self::EntityOutput as Entity>::IDPart::default(), id)
            .first()
    }

    /// Filters on the parts `EPL` against `values` using `FilterComponent` operator code `0`.
    // NOTE(review): presumably a less-than comparison, per the method name — confirm the
    // operator codes in `components`.
    fn filter_lt<'l, EPL: EntityPartList<Entity = Self::EntityOutput>>(
        self,
        _parts: EPL,
        values: impl BorrowedDatumList<'l, EPL::DatumList>,
    ) -> impl Queryable<EntityOutput = Self::EntityOutput, OutputContainer = Self::OutputContainer>
    {
        components::FilterComponent::<_, EPL, 0, _>::new(self, values)
    }

    /// Filters on the parts `EPL` against `values` using `FilterComponent` operator code `1`.
    // NOTE(review): presumably less-than-or-equal, per the method name — confirm.
    fn filter_lte<'l, EPL: EntityPartList<Entity = Self::EntityOutput>>(
        self,
        _parts: EPL,
        values: impl BorrowedDatumList<'l, EPL::DatumList>,
    ) -> impl Queryable<EntityOutput = Self::EntityOutput, OutputContainer = Self::OutputContainer>
    {
        components::FilterComponent::<_, EPL, 1, _>::new(self, values)
    }

    /// Filters on the parts `EPL` against `values` using `FilterComponent` operator code `2`.
    // NOTE(review): presumably equality, per the method name — confirm.
    fn filter_eq<'l, EPL: EntityPartList<Entity = Self::EntityOutput>>(
        self,
        _parts: EPL,
        values: impl BorrowedDatumList<'l, EPL::DatumList>,
    ) -> impl Queryable<EntityOutput = Self::EntityOutput, OutputContainer = Self::OutputContainer>
    {
        components::FilterComponent::<_, EPL, 2, _>::new(self, values)
    }

    /// Filters on the parts `EPL` against `values` using `FilterComponent` operator code `3`.
    // NOTE(review): presumably greater-than-or-equal, per the method name — confirm.
    fn filter_gte<'l, EPL: EntityPartList<Entity = Self::EntityOutput>>(
        self,
        _parts: EPL,
        values: impl BorrowedDatumList<'l, EPL::DatumList>,
    ) -> impl Queryable<EntityOutput = Self::EntityOutput, OutputContainer = Self::OutputContainer>
    {
        components::FilterComponent::<_, EPL, 3, _>::new(self, values)
    }

    /// Filters on the parts `EPL` against `values` using `FilterComponent` operator code `4`.
    // NOTE(review): presumably greater-than, per the method name — confirm.
    fn filter_gt<'l, EPL: EntityPartList<Entity = Self::EntityOutput>>(
        self,
        _parts: EPL,
        values: impl BorrowedDatumList<'l, EPL::DatumList>,
    ) -> impl Queryable<EntityOutput = Self::EntityOutput, OutputContainer = Self::OutputContainer>
    {
        components::FilterComponent::<_, EPL, 4, _>::new(self, values)
    }

    /// Filters a string-valued part against the pattern `filter` using `FilterComponent`
    /// operator code `5`.
    // NOTE(review): presumably SQL GLOB matching, per the method name — confirm.
    fn filter_glob<EP: EntityPart<Entity = Self::EntityOutput, Datum = String>>(
        self,
        _part: EP,
        filter: &str,
    ) -> impl Queryable<EntityOutput = Self::EntityOutput, OutputContainer = Self::OutputContainer>
    {
        components::FilterComponent::<_, EP, 5, _>::new(self, filter)
    }

    /// Filters a string-valued part against the pattern `filter` using `FilterComponent`
    /// operator code `6`. Only available with the `regex` feature.
    // NOTE(review): presumably regular-expression matching, per the method name — confirm.
    #[cfg(feature = "regex")]
    fn filter_regex<'l, EP: EntityPart<Entity = Self::EntityOutput, Datum = String>>(
        self,
        _part: EP,
        filter: &'l str,
    ) -> impl Queryable<EntityOutput = Self::EntityOutput, OutputContainer = Self::OutputContainer>
    {
        components::FilterComponent::<_, EP, 6, _>::new(self, filter)
    }

    /// Reduces the query to at most one result (wraps the query in a `SingleComponent`).
    fn first(
        self,
    ) -> impl Queryable<
        EntityOutput = Self::EntityOutput,
        OutputContainer = Option<Stored<Self::EntityOutput>>,
    >
    where
        Self: Sized,
    {
        components::SingleComponent::new(self)
    }

    /// Orders the results by `part` (an `OrderByComponent` with its const flag set to `true`).
    // NOTE(review): the flag presumably selects ascending order — confirm in `components`.
    fn order_by_asc<EPL: EntityPartList<Entity = Self::EntityOutput>>(
        self,
        part: EPL,
    ) -> impl Queryable<EntityOutput = Self::EntityOutput, OutputContainer = Self::OutputContainer>
    where
        Self: Sized,
    {
        components::OrderByComponent::<_, _, true>::new(self, part)
    }

    /// Orders the results by `part` (an `OrderByComponent` with its const flag set to `false`).
    // NOTE(review): the flag presumably selects descending order — confirm in `components`.
    fn order_by_desc<EPL: EntityPartList<Entity = Self::EntityOutput>>(
        self,
        part: EPL,
    ) -> impl Queryable<EntityOutput = Self::EntityOutput, OutputContainer = Self::OutputContainer>
    where
        Self: Sized,
    {
        components::OrderByComponent::<_, _, false>::new(self, part)
    }

    /// Caps the number of results at `limit` (wraps the query in a `LimitComponent`).
    fn limit(
        self,
        limit: usize,
    ) -> impl Queryable<EntityOutput = Self::EntityOutput, OutputContainer = Self::OutputContainer>
    where
        Self: Sized,
    {
        components::LimitComponent::new(self, limit)
    }

    /// Applies a limit together with an offset. Note the parameter order: `limit` comes
    /// first, then `offset`; both are forwarded to `LimitComponent::new_with_offset` in
    /// that order.
    fn offset_limit(
        self,
        limit: usize,
        offset: usize,
    ) -> impl Queryable<EntityOutput = Self::EntityOutput, OutputContainer = Self::OutputContainer>
    where
        Self: Sized,
    {
        components::LimitComponent::new_with_offset(self, limit, offset)
    }

    /// Follows the relation stored in `part`, producing a query over the relation's remote
    /// entities (wraps the query in a `JoinComponent`).
    fn join<
        AD: RelationInterface + Datum,
        EP: EntityPart<Entity = Self::EntityOutput, Datum = AD>,
    >(
        self,
        part: EP,
    ) -> impl Queryable<EntityOutput = AD::RemoteEntity, OutputContainer = Vec<Stored<AD::RemoteEntity>>>
    where
        Self: Sized,
    {
        components::JoinComponent::<AD::RemoteEntity, Self::EntityOutput, _, Self>::new(self, part)
    }

    /// Follows the entity ID stored in `part`, producing a query over the entity type that
    /// ID refers to (wraps the query in a `ForeignComponent`). The output container keeps
    /// its shape, with the entity type replaced.
    // The return type is inherently long; splitting it into aliases would not help readers.
    #[allow(clippy::type_complexity)]
    fn foreign<EP: EntityPart<Entity = Self::EntityOutput>>(
        self,
        part: EP,
    ) -> impl Queryable<
        EntityOutput = <EP::Datum as EntityID>::Entity,
        OutputContainer = <Self::OutputContainer as OutputContainer<
            Self::EntityOutput,
        >>::WithReplacedEntity<<EP::Datum as EntityID>::Entity>,
    >
    where
        Self: Sized,
        EP::Datum: EntityID,
    {
        components::ForeignComponent::<_, EP, Self>::new(self, part)
    }
}
/// A reference to a relation interface is itself a query over the remote entities connected
/// to the local entity through that relation.
impl<AI: RelationInterface> Queryable for &AI {
    type EntityOutput = AI::RemoteEntity;
    type OutputContainer = Vec<Stored<AI::RemoteEntity>>;
    type StaticVersion = &'static AI::StaticVersion;

    /// Builds a `SELECT DISTINCT` over the relation table, joined to the remote entity
    /// table, with one placeholder for the local entity ID.
    ///
    /// # Errors
    /// Propagates a failure to collect the relation names instead of panicking.
    fn build(&self) -> DBResult<Query<'_>> {
        let anames = RelationNames::collect(*self)?;
        let relation_name = anames.relation_name();
        Ok(Query::new()
            .attach(QueryPart::Root, "SELECT DISTINCT")
            .attach(QueryPart::Columns, format!("`{}`.*", anames.remote_name))
            .attach(QueryPart::From, format!("`{}`", relation_name))
            .attach(
                QueryPart::Join,
                format!(
                    "`{}` ON `{}`.`id` = `{}`.`{}`",
                    anames.remote_name, anames.remote_name, relation_name, anames.remote_field
                ),
            )
            .attach(
                QueryPart::Where,
                format!("`{}`.`{}` = ?", relation_name, anames.local_field),
            ))
    }

    /// Binds the local entity ID to the placeholder at `*index`, then advances `*index`.
    ///
    /// # Errors
    /// Propagates a failure to read the relation data (for example when the relation has
    /// no data) instead of panicking, as well as bind failures.
    fn bind(&self, ctx: &mut StatementContext, index: &mut i32) -> DBResult<()> {
        let rdata = self.get_data()?;
        ctx.bind(*index, rdata.local_id)?;
        *index += 1;
        Ok(())
    }
}
impl<E: Entity, EPL: EntityPartList<Entity = E>> Index<true, E, EPL> {
pub fn search<'a>(
&'a self,
values: impl BorrowedDatumList<'a, EPL::DatumList> + 'a,
) -> impl 'a + Queryable<EntityOutput = E, OutputContainer = Option<Stored<E>>> {
self.indexed(self, values)
}
}
/// A reference to an index acts as an unfiltered query over the whole table of `E`.
impl<const UNIQUE: bool, E: Entity, EPL: EntityPartList<Entity = E>> Queryable
    for &Index<UNIQUE, E, EPL>
{
    type EntityOutput = E;
    type OutputContainer = Vec<Stored<E>>;
    type StaticVersion = &'static Index<UNIQUE, E, EPL>;

    /// Builds `SELECT DISTINCT * FROM` the entity's table, with no conditions.
    fn build(&self) -> DBResult<Query<'_>> {
        let table = format!("`{}`", E::entity_name());
        let mut query = Query::new();
        query.attach_mut(QueryPart::Root, "SELECT DISTINCT");
        query.attach_mut(QueryPart::Columns, "*");
        query.attach_mut(QueryPart::From, table);
        Ok(query)
    }

    /// Nothing to bind: the statement built above has no placeholders.
    fn bind(&self, _stmt: &mut StatementContext, _index: &mut i32) -> DBResult<()> {
        Ok(())
    }
}