use crate::dbms::query::{Filter, OrderDirection, Query};
use crate::prelude::{Join, JoinType};
/// Fluent builder for [`Query`] values.
///
/// Every configuration method takes the builder by value and hands it back, so calls can be
/// chained; [`QueryBuilder::build`] returns the finished query. A builder obtained through
/// `Default` wraps `Query::default()`.
#[derive(Clone, Debug, Default)]
pub struct QueryBuilder {
    /// The query being assembled; returned as-is by `build`.
    query: Query,
}
impl QueryBuilder {
pub fn build(self) -> Query {
self.query
}
pub fn field(mut self, field: &str) -> Self {
let field = field.to_string();
match &mut self.query.columns {
crate::dbms::query::Select::All => {
self.query.columns = crate::dbms::query::Select::Columns(vec![field]);
}
crate::dbms::query::Select::Columns(cols) if !cols.contains(&field) => {
cols.push(field);
}
_ => {}
}
self
}
pub fn fields<I>(mut self, fields: I) -> Self
where
I: IntoIterator<Item = &'static str>,
{
for field in fields {
self = self.field(field);
}
self
}
pub fn all(mut self) -> Self {
self.query.columns = crate::dbms::query::Select::All;
self
}
pub fn with(mut self, table_relation: &str) -> Self {
let table_relation = table_relation.to_string();
if !self.query.eager_relations.contains(&table_relation) {
self.query.eager_relations.push(table_relation);
}
self
}
pub fn inner_join(self, table: &str, left_col: &str, right_col: &str) -> Self {
self.join(JoinType::Inner, table, left_col, right_col)
}
pub fn left_join(self, table: &str, left_col: &str, right_col: &str) -> Self {
self.join(JoinType::Left, table, left_col, right_col)
}
pub fn right_join(self, table: &str, left_col: &str, right_col: &str) -> Self {
self.join(JoinType::Right, table, left_col, right_col)
}
pub fn full_join(self, table: &str, left_col: &str, right_col: &str) -> Self {
self.join(JoinType::Full, table, left_col, right_col)
}
pub fn order_by_asc(mut self, field: &str) -> Self {
self.query
.order_by
.push((field.to_string(), OrderDirection::Ascending));
self
}
pub fn order_by_desc(mut self, field: &str) -> Self {
self.query
.order_by
.push((field.to_string(), OrderDirection::Descending));
self
}
pub fn limit(mut self, limit: usize) -> Self {
self.query.limit = Some(limit);
self
}
pub fn offset(mut self, offset: usize) -> Self {
self.query.offset = Some(offset);
self
}
pub fn filter(mut self, filter: Option<Filter>) -> Self {
self.query.filter = filter;
self
}
pub fn and_where(mut self, filter: Filter) -> Self {
self.query.filter = match self.query.filter {
Some(existing_filter) => Some(existing_filter.and(filter)),
None => Some(filter),
};
self
}
pub fn or_where(mut self, filter: Filter) -> Self {
self.query.filter = match self.query.filter {
Some(existing_filter) => Some(existing_filter.or(filter)),
None => Some(filter),
};
self
}
fn join(mut self, join_type: JoinType, table: &str, left_col: &str, right_col: &str) -> Self {
self.query.joins.push(Join {
join_type,
table: table.to_string(),
left_column: left_col.to_string(),
right_column: right_col.to_string(),
});
self
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dbms::query::{JoinType, Select};
    use crate::dbms::value::Value;
    use crate::tests::User;

    /// A default builder yields a query with nothing configured.
    #[test]
    fn test_default_query_builder() {
        let query = QueryBuilder::default().build();

        assert!(matches!(query.columns, Select::All));
        assert!(query.eager_relations.is_empty());
        assert!(query.filter.is_none());
        assert!(query.order_by.is_empty());
        assert!(query.limit.is_none());
        assert!(query.offset.is_none());
    }

    /// Fields added one at a time are selected in insertion order.
    #[test]
    fn test_should_add_field_to_query_builder() {
        let query = QueryBuilder::default().field("id").field("name").build();

        assert_eq!(query.columns::<User>(), vec!["id", "name"]);
    }

    /// Fields added in bulk are selected in iteration order.
    #[test]
    fn test_should_set_fields() {
        let query = QueryBuilder::default().fields(["id", "email"]).build();

        assert_eq!(query.columns::<User>(), vec!["id", "email"]);
    }

    /// `all` overrides a previously added field.
    #[test]
    fn test_should_set_all_fields() {
        let query = QueryBuilder::default().field("id").all().build();

        assert!(matches!(query.columns, Select::All));
    }

    /// `with` registers an eager relation.
    #[test]
    fn test_should_add_eager_relation() {
        let query = QueryBuilder::default().with("posts").build();

        assert_eq!(query.eager_relations, vec!["posts"]);
    }

    /// Registering the same relation twice keeps a single entry.
    #[test]
    fn test_should_not_duplicate_eager_relation() {
        let query = QueryBuilder::default().with("posts").with("posts").build();

        assert_eq!(query.eager_relations, vec!["posts"]);
    }

    /// Orderings accumulate in call order with their respective directions.
    #[test]
    fn test_should_add_order_by_clauses() {
        let query = QueryBuilder::default()
            .order_by_asc("name")
            .order_by_desc("created_at")
            .build();

        let expected = vec![
            ("name".to_string(), OrderDirection::Ascending),
            ("created_at".to_string(), OrderDirection::Descending),
        ];
        assert_eq!(query.order_by, expected);
    }

    /// `limit` and `offset` are stored independently.
    #[test]
    fn test_should_set_limit_and_offset() {
        let query = QueryBuilder::default().limit(10).offset(5).build();

        assert_eq!(query.limit, Some(10));
        assert_eq!(query.offset, Some(5));
    }

    /// `and_where` followed by `or_where` produces an OR node wrapping both filters.
    #[test]
    fn test_should_create_filters() {
        let query = QueryBuilder::default()
            .all()
            .and_where(Filter::eq("id", Value::Uint32(1u32.into())))
            .or_where(Filter::like("name", "John%"))
            .build();

        match query.filter.expect("should have filter") {
            Filter::Or(left, right) => {
                assert!(matches!(*left, Filter::Eq(id, Value::Uint32(_)) if id == "id"));
                assert!(matches!(*right, Filter::Like(name, _) if name == "name"));
            }
            _ => panic!("Expected OR filter at the top level"),
        }
    }

    /// An inner join records its type, table and both columns.
    #[test]
    fn test_should_add_inner_join() {
        let query = QueryBuilder::default()
            .all()
            .inner_join("posts", "id", "user")
            .build();

        assert_eq!(query.joins.len(), 1);
        let join = &query.joins[0];
        assert_eq!(join.join_type, JoinType::Inner);
        assert_eq!(join.table, "posts");
        assert_eq!(join.left_column, "id");
        assert_eq!(join.right_column, "user");
    }

    /// `left_join` records a left join.
    #[test]
    fn test_should_add_left_join() {
        let query = QueryBuilder::default()
            .all()
            .left_join("posts", "id", "user")
            .build();

        assert_eq!(query.joins[0].join_type, JoinType::Left);
    }

    /// `right_join` records a right join.
    #[test]
    fn test_should_add_right_join() {
        let query = QueryBuilder::default()
            .all()
            .right_join("posts", "id", "user")
            .build();

        assert_eq!(query.joins[0].join_type, JoinType::Right);
    }

    /// `full_join` records a full join.
    #[test]
    fn test_should_add_full_join() {
        let query = QueryBuilder::default()
            .all()
            .full_join("posts", "id", "user")
            .build();

        assert_eq!(query.joins[0].join_type, JoinType::Full);
    }

    /// Several joins can be chained; each call appends one entry.
    #[test]
    fn test_should_chain_multiple_joins() {
        let query = QueryBuilder::default()
            .all()
            .inner_join("posts", "id", "user")
            .left_join("comments", "posts.id", "post_id")
            .build();

        assert_eq!(query.joins.len(), 2);
    }
}