use std::marker::PhantomData;
use field_access::FieldAccess;
use sqlx::{Database, Encode, QueryBuilder, Type};
use crate::common::{helper::get_table_name, types::JoinType};
/// One fragment of a subquery, stored in emission order.
enum SubqueryPart<VAL> {
    /// Literal SQL text, pushed verbatim into the outer query.
    Text(String),
    /// A value that is handed to the outer query's `push_bind`.
    Bind(VAL),
}
/// A `SELECT ...` subquery assembled clause by clause and later spliced, in
/// parentheses, into an outer sqlx `QueryBuilder` by `append_to`.
///
/// `ET` supplies the default table name and column list; `VAL` is the type of the
/// values bound as placeholders.
pub struct Subquery<'a, ET, VAL>
where
    ET: FieldAccess + Default,
    VAL: 'a,
{
    /// Table named after `FROM`.
    table_name: String,
    /// SQL text and bind values, replayed in this order when the subquery is appended.
    parts: Vec<SubqueryPart<VAL>>,
    /// Set once `<columns> FROM <table>` has been emitted.
    has_from: bool,
    /// Set once ` WHERE ` has been emitted.
    has_filter: bool,
    /// Set once ` GROUP BY ` has been emitted.
    has_group_by: bool,
    /// Set once ` HAVING ` has been emitted.
    has_having: bool,
    /// Ties `ET` and `'a` to the type without storing values of them.
    _phantom: PhantomData<(ET, &'a ())>,
}
/// Clause methods append SQL in the order they are called, so call them in SQL order:
/// `columns`, `join`, `filter`, `group_by`, `having`.
///
/// Table names, field names and text pushed through the closures are inserted
/// verbatim (not escaped). Only values passed to `push_bind` are parameterized, so
/// never feed untrusted input into the textual parts.
impl<'a, ET, VAL> Subquery<'a, ET, VAL>
where
    ET: FieldAccess + Default,
    VAL: 'a,
{
    /// Starts a subquery over the table that `get_table_name` reports for `ET`.
    pub fn table() -> Self {
        Self::with_table(&get_table_name::<ET>())
    }

    /// Starts a subquery over an explicitly named table.
    ///
    /// Only `SELECT ` is emitted here. The column list and `FROM` clause come from
    /// [`columns`](Self::columns), or are filled in with all of `ET`'s fields by the
    /// first clause method or by [`append_to`](Self::append_to).
    pub fn with_table(table_name: &str) -> Self {
        Self {
            parts: vec![SubqueryPart::Text("SELECT ".to_string())],
            table_name: table_name.to_string(),
            has_from: false,
            has_filter: false,
            has_group_by: false,
            has_having: false,
            _phantom: PhantomData,
        }
    }

    /// Makes sure the accumulated SQL ends in whitespace so the next fragment is not
    /// glued onto the previous token.
    fn ensure_trailing_space(&mut self) {
        match self.parts.last_mut() {
            Some(SubqueryPart::Text(last)) => {
                if !last.ends_with(' ') {
                    last.push(' ');
                }
            }
            // A bind placeholder cannot be padded in place. Start a new text part
            // holding the separator; the closure's `push` will extend it.
            Some(SubqueryPart::Bind(_)) => self.parts.push(SubqueryPart::Text(" ".to_string())),
            None => {}
        }
    }

    /// Separates the previous fragment with a space, then lets `f` append text and
    /// binds through a [`SubqueryBuilder`].
    fn push_part(&mut self, f: impl FnOnce(&mut SubqueryBuilder<'_, VAL>)) {
        self.ensure_trailing_space();
        let mut builder = SubqueryBuilder { parts: &mut self.parts };
        f(&mut builder);
    }

    /// Sets the column list via `column_build_fn` and emits ` FROM <table>`.
    ///
    /// Ignored (returns `self` unchanged) once a `FROM` clause exists, including the
    /// default one inserted by `filter`, `join`, `group_by` or `append_to`.
    pub fn columns(
        mut self,
        column_build_fn: impl FnOnce(&mut SubqueryBuilder<'_, VAL>),
    ) -> Self {
        if self.has_from {
            return self;
        }
        self.push_part(column_build_fn);
        self.parts.push(SubqueryPart::Text(format!(" FROM {}", &self.table_name)));
        self.has_from = true;
        self
    }

    /// Emits `<all field names of ET> FROM <table>` and marks `FROM` as present.
    fn add_from_clause(&mut self) {
        let columns = ET::default().field_names().join(", ");
        self.parts.push(SubqueryPart::Text(format!("{} FROM {}", columns, &self.table_name)));
        self.has_from = true;
    }

    /// Emits the default column list and `FROM` clause unless one already exists.
    fn ensure_from(&mut self) {
        if !self.has_from {
            self.add_from_clause();
        }
    }

    /// Appends a condition built by `f` to the `WHERE` clause.
    ///
    /// Only the first call emits ` WHERE `. Later calls are separated from the
    /// previous condition by a single space, so the closure must supply its own
    /// connective (e.g. `AND ...`), as with [`having`](Self::having).
    pub fn filter(mut self, f: impl FnOnce(&mut SubqueryBuilder<'_, VAL>)) -> Self {
        self.ensure_from();
        if !self.has_filter {
            self.parts.push(SubqueryPart::Text(" WHERE ".to_string()));
            // Without this flag every call would emit another `WHERE`.
            self.has_filter = true;
        }
        self.push_part(f);
        self
    }

    /// Appends `<JOIN KEYWORD> <table> ON <condition>`, with the condition built by
    /// `on_condition`.
    ///
    /// For [`JoinType::Cross`] no `ON` is emitted, and `on_condition` should push
    /// nothing.
    pub fn join(
        mut self,
        join_type: JoinType,
        table: impl Into<String>,
        on_condition: impl FnOnce(&mut SubqueryBuilder<'_, VAL>),
    ) -> Self {
        self.ensure_from();
        let (join_keyword, takes_on) = match join_type {
            JoinType::Inner => ("INNER JOIN", true),
            JoinType::Left => ("LEFT JOIN", true),
            JoinType::Right => ("RIGHT JOIN", true),
            JoinType::Full => ("FULL JOIN", true),
            JoinType::Cross => ("CROSS JOIN", false),
        };
        let table = table.into();
        // SQL grammar is `<keyword> <table> [ON <condition>]`. The leading space keeps
        // the keyword off the end of the preceding table name or bind placeholder.
        let clause = if takes_on {
            format!(" {} {} ON ", join_keyword, table)
        } else {
            format!(" {} {}", join_keyword, table)
        };
        self.parts.push(SubqueryPart::Text(clause));
        self.push_part(on_condition);
        self
    }

    /// Adds `field` to the `GROUP BY` list.
    ///
    /// The first call emits ` GROUP BY `; later calls prefix the field with `, `.
    pub fn group_by(mut self, field: impl Into<String>) -> Self {
        self.ensure_from();
        let separator = if self.has_group_by { ", " } else { " GROUP BY " };
        self.has_group_by = true;
        self.parts.push(SubqueryPart::Text(separator.to_string()));
        self.parts.push(SubqueryPart::Text(field.into()));
        self
    }

    /// Appends a condition built by `condition` to the `HAVING` clause.
    ///
    /// Silently ignored unless [`group_by`](Self::group_by) was called first. Only the
    /// first call emits ` HAVING `; later conditions need their own connective.
    pub fn having(
        mut self,
        condition: impl FnOnce(&mut SubqueryBuilder<'_, VAL>),
    ) -> Self {
        if !self.has_group_by {
            return self;
        }
        if !self.has_having {
            self.parts.push(SubqueryPart::Text(" HAVING ".into()));
            self.has_having = true;
        }
        self.push_part(condition);
        self
    }

    /// Consumes the subquery and writes it, wrapped as ` (...) `, into `query_builder`.
    ///
    /// Text parts go through `push` verbatim and bound values through `push_bind`, in
    /// the order they were recorded. A default `FROM` clause is added if none exists.
    pub fn append_to<DB>(mut self, query_builder: &mut QueryBuilder<'a, DB>)
    where
        VAL: Encode<'a, DB> + Type<DB>,
        DB: Database,
    {
        query_builder.push(" (");
        self.ensure_from();
        for part in self.parts {
            match part {
                SubqueryPart::Text(text) => query_builder.push(&text),
                SubqueryPart::Bind(val) => query_builder.push_bind(val),
            };
        }
        query_builder.push(") ");
    }
}
/// Handle given to the closures of `Subquery`'s clause methods for appending raw
/// SQL text and bound values to the subquery being built.
pub struct SubqueryBuilder<'a, VAL> {
    parts: &'a mut Vec<SubqueryPart<VAL>>,
}

impl<'a, VAL> SubqueryBuilder<'a, VAL> {
    /// Appends raw SQL text exactly as given. No separator is inserted.
    ///
    /// Consecutive text is merged into one fragment, extended in place. A new fragment
    /// is started only after a bind value or when nothing has been recorded yet.
    pub fn push(&mut self, s: &str) -> &mut Self {
        match self.parts.last_mut() {
            // `push_str` grows the buffer in place. Rebuilding it with `format!` on
            // every call made repeated pushes quadratic.
            Some(SubqueryPart::Text(last)) => last.push_str(s),
            _ => self.parts.push(SubqueryPart::Text(s.to_string())),
        }
        self
    }

    /// Records `val` as a bound parameter, emitted later via the outer query's
    /// `push_bind`.
    pub fn push_bind(&mut self, val: VAL) -> &mut Self {
        self.parts.push(SubqueryPart::Bind(val));
        self
    }
}