use std::{cmp::Ordering, collections::HashMap, marker::PhantomData, str::FromStr};
use rustrails_support::{database, runtime};
use sea_orm::{
    ColumnTrait, Condition, ConnectionTrait, DatabaseConnection, DbBackend, EntityTrait, ExprTrait,
    FromQueryResult, IdenStatic, Iterable, Order, PaginatorTrait, QueryFilter, QueryOrder,
    QuerySelect, QueryTrait, Select, Statement, Value,
    sea_query::{ColumnType, Expr, SimpleExpr},
};
use serde_json::{Value as JsonValue, json};
use crate::{OrderDirection, Record, RecordError, RecordState};
/// One AND-ed set of column conditions; a relation ORs its non-empty groups together.
#[derive(Debug, Clone, Default)]
struct ConditionGroup {
    /// Terms the row must satisfy (rendered by `build_filter` with `negated = false`).
    positive: Vec<(String, JsonValue)>,
    /// Terms rendered negated by `build_filter` (`negated = true`).
    negative: Vec<(String, JsonValue)>,
}
impl ConditionGroup {
    /// Builds a group made only of positive terms.
    ///
    /// Entries are sorted by column name because `HashMap` iteration order is randomized;
    /// a stable order keeps the generated SQL text identical from one run to the next.
    fn from_positive(conditions: HashMap<String, JsonValue>) -> Self {
        let mut positive = conditions.into_iter().collect::<Vec<_>>();
        positive.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));
        Self {
            positive,
            negative: Vec::new(),
        }
    }
    /// True when the group holds no terms at all; such a group constrains nothing.
    fn is_empty(&self) -> bool {
        self.positive.is_empty() && self.negative.is_empty()
    }
}
/// Lazily built query over the table backing record type `T`.
///
/// Builder methods consume and return the relation; the database is only contacted by the
/// loading methods (`load`, `first`, `count`, `pluck`, ...).
#[derive(Debug)]
pub struct Relation<T: Record> {
    // OR-ed groups of AND-ed conditions; a fresh relation holds one empty group.
    where_groups: Vec<ConditionGroup>,
    // `ORDER BY` terms in insertion order.
    order_by: Vec<(String, OrderDirection)>,
    // `GROUP BY` column names; the first one keys the `group_*` calculations.
    group_columns: Vec<String>,
    // Raw SQL `HAVING` fragments.
    having_conditions: Vec<String>,
    // Whether `SELECT DISTINCT` was requested.
    is_distinct: bool,
    // Projected column names; empty means every column.
    select_columns: Vec<String>,
    // Association names recorded by `joins`.
    join_associations: Vec<String>,
    // Association names recorded by `includes`.
    included_associations: Vec<String>,
    // `LIMIT`, if any.
    limit_val: Option<u64>,
    // `OFFSET`, if any.
    offset_val: Option<u64>,
    // Ties the relation to `T` without storing a value of it.
    _phantom: PhantomData<T>,
}
impl<T: Record> Clone for Relation<T> {
fn clone(&self) -> Self {
Self {
where_groups: self.where_groups.clone(),
order_by: self.order_by.clone(),
group_columns: self.group_columns.clone(),
having_conditions: self.having_conditions.clone(),
is_distinct: self.is_distinct,
select_columns: self.select_columns.clone(),
join_associations: self.join_associations.clone(),
included_associations: self.included_associations.clone(),
limit_val: self.limit_val,
offset_val: self.offset_val,
_phantom: PhantomData,
}
}
}
impl<T: Record> Default for Relation<T> {
    /// An unconstrained relation: a single empty condition group and no other clauses.
    fn default() -> Self {
        let where_groups = vec![ConditionGroup::default()];
        Self {
            where_groups,
            order_by: Vec::default(),
            group_columns: Vec::default(),
            having_conditions: Vec::default(),
            is_distinct: bool::default(),
            select_columns: Vec::default(),
            join_associations: Vec::default(),
            included_associations: Vec::default(),
            limit_val: Option::default(),
            offset_val: Option::default(),
            _phantom: PhantomData,
        }
    }
}
impl<T: Record> Relation<T> {
    /// Creates an unconstrained relation over the table backing `T`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `column = value` conditions (a `null` value becomes `IS NULL`).
    ///
    /// The conditions are AND-ed into every existing OR group, so
    /// `where(a).or_where(b).where(c)` matches `(a AND c) OR (b AND c)`.
    /// An empty map is ignored.
    pub fn r#where(mut self, conditions: HashMap<String, JsonValue>) -> Self {
        self.add_positive_conditions(conditions);
        self
    }

    /// Appends an `ORDER BY column dir` term after any existing ordering.
    pub fn order(mut self, column: &str, dir: OrderDirection) -> Self {
        self.order_by.push((column.to_owned(), dir));
        self
    }

    /// Appends a `GROUP BY` column; the first one is also the key of the `group_*`
    /// calculations.
    pub fn group(mut self, column: &str) -> Self {
        self.group_columns.push(column.to_owned());
        self
    }

    /// Appends a raw SQL `HAVING` fragment.
    ///
    /// `load` passes the text to the database verbatim, so it must never be built from
    /// untrusted input. The `group_*` calculations evaluate it in memory instead and only
    /// understand `<group column | AGG(column)> <operator> <literal>`.
    pub fn having(mut self, condition: &str) -> Self {
        self.having_conditions.push(condition.to_owned());
        self
    }

    /// Requests `SELECT DISTINCT`.
    pub fn distinct(mut self) -> Self {
        self.is_distinct = true;
        self
    }

    /// Adds `columns` to the projection (appending to any earlier selection).
    ///
    /// Columns that are not selected are still returned, filled with a placeholder literal
    /// where one exists, so that full models can be decoded.
    pub fn select_columns(mut self, columns: &[&str]) -> Self {
        self.select_columns
            .extend(columns.iter().map(|column| (*column).to_owned()));
        self
    }

    /// Alias for [`Self::select_columns`].
    pub fn select(self, columns: &[&str]) -> Self {
        self.select_columns(columns)
    }

    /// Records an association to join.
    ///
    /// NOTE(review): `build_select` never reads `join_associations`, so this does not change
    /// the generated SQL — confirm whether it is consumed elsewhere.
    pub fn joins(mut self, association: &str) -> Self {
        self.join_associations.push(association.to_owned());
        self
    }

    /// Records an association to eager-load.
    ///
    /// NOTE(review): no loader in this file reads `included_associations` — confirm where it
    /// is consumed.
    pub fn includes(mut self, association: &str) -> Self {
        self.included_associations.push(association.to_owned());
        self
    }

    /// Replaces all existing ordering with a single `ORDER BY column dir` term.
    pub fn reorder(mut self, column: &str, dir: OrderDirection) -> Self {
        self.order_by.clear();
        self.order_by.push((column.to_owned(), dir));
        self
    }

    /// Replaces the projection with `columns`.
    pub fn reselect(mut self, columns: &[&str]) -> Self {
        self.select_columns.clear();
        self.select_columns
            .extend(columns.iter().map(|column| (*column).to_owned()));
        self
    }

    /// Replaces every condition (OR groups and negations included) with `conditions`.
    pub fn rewhere(mut self, conditions: HashMap<String, JsonValue>) -> Self {
        self.where_groups = vec![ConditionGroup::from_positive(conditions)];
        self
    }

    /// Sets `LIMIT n`, replacing any earlier limit.
    pub fn limit(mut self, n: u64) -> Self {
        self.limit_val = Some(n);
        self
    }

    /// Sets `OFFSET n`, replacing any earlier offset.
    pub fn offset(mut self, n: u64) -> Self {
        self.offset_val = Some(n);
        self
    }

    /// Adds negated conditions (`column != value`, `IS NOT NULL` for `null`) to every OR group.
    pub fn not_where(mut self, conditions: HashMap<String, JsonValue>) -> Self {
        self.add_negative_conditions(conditions);
        self
    }

    /// Alias for [`Self::not_where`].
    pub fn not(self, conditions: HashMap<String, JsonValue>) -> Self {
        self.not_where(conditions)
    }

    /// Starts a new OR group containing `conditions`.
    ///
    /// An empty map is ignored. When the relation still only holds its initial empty group,
    /// that group is replaced instead, so `Relation::new().or_where(x)` is simply `x`.
    pub fn or_where(mut self, conditions: HashMap<String, JsonValue>) -> Self {
        if conditions.is_empty() {
            return self;
        }
        let group = ConditionGroup::from_positive(conditions);
        if self.where_groups.len() == 1 && self.where_groups[0].is_empty() {
            self.where_groups[0] = group;
        } else {
            self.where_groups.push(group);
        }
        self
    }

    /// AND-s `column = value` terms into every condition group.
    fn add_positive_conditions(&mut self, conditions: HashMap<String, JsonValue>) {
        self.push_conditions(conditions, false);
    }

    /// AND-s negated terms into every condition group.
    fn add_negative_conditions(&mut self, conditions: HashMap<String, JsonValue>) {
        self.push_conditions(conditions, true);
    }

    /// Shared implementation of the two helpers above.
    ///
    /// Entries are sorted by column name because `HashMap` iteration order is randomized;
    /// a stable order keeps the generated SQL text identical between runs.
    fn push_conditions(&mut self, conditions: HashMap<String, JsonValue>, negated: bool) {
        if conditions.is_empty() {
            return;
        }
        let mut entries = conditions.into_iter().collect::<Vec<_>>();
        entries.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));
        if self.where_groups.is_empty() {
            self.where_groups.push(ConditionGroup::default());
        }
        for group in &mut self.where_groups {
            let terms = if negated {
                &mut group.negative
            } else {
                &mut group.positive
            };
            terms.extend(entries.iter().cloned());
        }
    }

    /// Turns a loaded model into a record flagged as persisted.
    fn persisted(model: <T::Entity as EntityTrait>::Model) -> T
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let mut record = T::from_sea_model(model);
        record.set_record_state(RecordState::Persisted);
        record
    }

    /// Runs the query and returns every matching record, marked as persisted.
    ///
    /// # Errors
    /// `RecordError::Invalid` for unknown column names, or any database error.
    pub async fn load(&self, db: &DatabaseConnection) -> Result<Vec<T>, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let models = self.build_select()?.all(db).await?;
        Ok(models.into_iter().map(Self::persisted).collect())
    }

    /// Blocking wrapper around [`Self::load`] (via `database::with_db` + `runtime::block_on`).
    pub fn load_sync(&self) -> Result<Vec<T>, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.load(db)))
    }

    /// Calls `f` on every record, loading them in batches of `batch_size`.
    pub async fn find_each<F>(
        &self,
        batch_size: u64,
        db: &DatabaseConnection,
        mut f: F,
    ) -> Result<(), RecordError>
    where
        F: FnMut(T),
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        self.find_in_batches(batch_size, db, |batch| {
            for record in batch {
                f(record);
            }
        })
        .await
    }

    /// Blocking wrapper around [`Self::find_each`].
    pub fn find_each_sync<F>(&self, batch_size: u64, mut f: F) -> Result<(), RecordError>
    where
        F: FnMut(T),
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.find_each(batch_size, db, &mut f)))
    }

    /// Loads the relation page by page (at most `batch_size` rows each) and hands every
    /// non-empty page to `f`.
    ///
    /// Paging uses LIMIT/OFFSET starting at the relation's own offset and stops once the
    /// relation's own limit (if any) is exhausted. Without an `order`, SQL does not promise a
    /// stable row order between pages.
    ///
    /// # Errors
    /// `RecordError::Invalid` when `batch_size` is zero, or any database error.
    pub async fn find_in_batches<F>(
        &self,
        batch_size: u64,
        db: &DatabaseConnection,
        mut f: F,
    ) -> Result<(), RecordError>
    where
        F: FnMut(Vec<T>),
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        if batch_size == 0 {
            return Err(RecordError::Invalid(
                "batch size must be greater than zero".to_owned(),
            ));
        }
        let mut offset = self.offset_val.unwrap_or(0);
        let mut remaining = self.limit_val;
        loop {
            let current_batch_size =
                remaining.map_or(batch_size, |limit| Ord::min(limit, batch_size));
            if current_batch_size == 0 {
                break;
            }
            let batch = Self::clone(self)
                .limit(current_batch_size)
                .offset(offset)
                .load(db)
                .await?;
            if batch.is_empty() {
                break;
            }
            let loaded = batch.len() as u64;
            f(batch);
            // A short page means the result set is exhausted.
            if loaded < current_batch_size {
                break;
            }
            offset += loaded;
            if let Some(remaining_count) = remaining.as_mut() {
                *remaining_count = remaining_count.saturating_sub(loaded);
                if *remaining_count == 0 {
                    break;
                }
            }
        }
        Ok(())
    }

    /// Blocking wrapper around [`Self::find_in_batches`].
    pub fn find_in_batches_sync<F>(&self, batch_size: u64, mut f: F) -> Result<(), RecordError>
    where
        F: FnMut(Vec<T>),
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.find_in_batches(batch_size, db, &mut f)))
    }

    /// Returns the first matching record, or `None`.
    pub async fn first(&self, db: &DatabaseConnection) -> Result<Option<T>, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        // The `.limit(1)` below replaces the relation's own limit, so an explicit `limit(0)`
        // has to be honored here or it would still return a row.
        if self.limit_val == Some(0) {
            return Ok(None);
        }
        let model = self.build_select()?.limit(1).one(db).await?;
        Ok(model.map(Self::persisted))
    }

    /// Blocking wrapper around [`Self::first`].
    pub fn first_sync(&self) -> Result<Option<T>, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.first(db)))
    }

    /// Returns the last record of the loaded result, or `None`.
    ///
    /// NOTE(review): this loads every matching row; without an `order` the notion of "last"
    /// is whatever order the database returns.
    pub async fn last(&self, db: &DatabaseConnection) -> Result<Option<T>, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let mut records = self.load(db).await?;
        Ok(records.pop())
    }

    /// Blocking wrapper around [`Self::last`].
    pub fn last_sync(&self) -> Result<Option<T>, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.last(db)))
    }

    /// Number of rows the relation would load.
    ///
    /// Without LIMIT/OFFSET the count is computed by the database through sea-orm's
    /// `PaginatorTrait::count`, so no rows are transferred. With a LIMIT or OFFSET the rows
    /// of that window are loaded and counted, which keeps the result bounded by the window.
    pub async fn count(&self, db: &DatabaseConnection) -> Result<u64, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
        <T::Entity as EntityTrait>::Model: FromQueryResult + Send + Sync,
    {
        let query = self.build_select()?;
        if self.limit_val.is_none() && self.offset_val.is_none() {
            return PaginatorTrait::count(query, db).await.map_err(Into::into);
        }
        Ok(query.all(db).await?.len() as u64)
    }

    /// Blocking wrapper around [`Self::count`].
    pub fn count_sync(&self) -> Result<u64, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
        <T::Entity as EntityTrait>::Model: FromQueryResult + Send + Sync,
    {
        database::with_db(|db| runtime::block_on(self.count(db)))
    }

    /// True when at least one record matches.
    pub async fn exists(&self, db: &DatabaseConnection) -> Result<bool, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        Ok(self.first(db).await?.is_some())
    }

    /// Blocking wrapper around [`Self::exists`].
    pub fn exists_sync(&self) -> Result<bool, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.exists(db)))
    }

    /// Returns the only matching record.
    ///
    /// # Errors
    /// `RecordError::NotFound` when nothing matches, `RecordError::SoleRecordExceeded` when
    /// more than one record matches.
    pub async fn sole(&self, db: &DatabaseConnection) -> Result<T, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        // Two rows are enough to detect ambiguity, but never widen a smaller limit the caller
        // already set: `limit(1).sole()` yields exactly one row and must succeed.
        let probe_limit = self.limit_val.map_or(2, |limit| limit.min(2));
        let mut records = self.clone().limit(probe_limit).load(db).await?;
        match records.len() {
            0 => Err(RecordError::NotFound),
            1 => Ok(records.pop().expect("exactly one row should be loaded")),
            _ => Err(RecordError::SoleRecordExceeded),
        }
    }

    /// Blocking wrapper around [`Self::sole`].
    pub fn sole_sync(&self) -> Result<T, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.sole(db)))
    }

    /// Runs `EXPLAIN` (SQLite: `EXPLAIN QUERY PLAN`) for the generated query.
    ///
    /// On SQLite the `detail` column of each row is collected; on other backends the first
    /// column, when it decodes as a string. Lines are joined with `\n`.
    pub async fn explain(&self, db: &DatabaseConnection) -> Result<String, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let backend = db.get_database_backend();
        let statement = self.build_select()?.build(backend);
        let explain_statement = Statement {
            sql: format!("{} {}", explain_prefix(backend), statement.sql),
            values: statement.values,
            db_backend: statement.db_backend,
        };
        let rows = db.query_all_raw(explain_statement).await?;
        let lines = rows
            .into_iter()
            .filter_map(|row| {
                if backend == DbBackend::Sqlite {
                    return row.try_get::<String>("", "detail").ok();
                }
                row.try_get_by_index::<String>(0).ok()
            })
            .collect::<Vec<_>>();
        Ok(lines.join("\n"))
    }

    /// Blocking wrapper around [`Self::explain`].
    pub fn explain_sync(&self) -> Result<String, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.explain(db)))
    }

    /// Sum of `column` over the loaded records, computed in memory; `null` counts as zero.
    ///
    /// # Errors
    /// `RecordError::Invalid` when a value is not numeric or the column is unknown.
    pub async fn sum(&self, column: &str, db: &DatabaseConnection) -> Result<f64, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        self.pluck(column, db)
            .await?
            .into_iter()
            .try_fold(0.0, |total, value| {
                Ok(total + json_value_to_f64(&value, column)?.unwrap_or(0.0))
            })
    }

    /// Blocking wrapper around [`Self::sum`].
    pub fn sum_sync(&self, column: &str) -> Result<f64, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.sum(column, db)))
    }

    /// Mean of the non-null values of `column`; `0.0` when there are none.
    pub async fn average(&self, column: &str, db: &DatabaseConnection) -> Result<f64, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let (total, count) = self.pluck(column, db).await?.into_iter().try_fold(
            (0.0, 0_u64),
            |(total, count), value| {
                Ok::<(f64, u64), RecordError>(match json_value_to_f64(&value, column)? {
                    Some(number) => (total + number, count + 1),
                    None => (total, count),
                })
            },
        )?;
        if count == 0 {
            return Ok(0.0);
        }
        Ok(total / count as f64)
    }

    /// Blocking wrapper around [`Self::average`].
    pub fn average_sync(&self, column: &str) -> Result<f64, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.average(column, db)))
    }

    /// Smallest non-null value of `column`, or `None`.
    pub async fn minimum(
        &self,
        column: &str,
        db: &DatabaseConnection,
    ) -> Result<Option<JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        relation_extreme(self.pluck(column, db).await?, column, Ordering::Less)
    }

    /// Blocking wrapper around [`Self::minimum`].
    pub fn minimum_sync(&self, column: &str) -> Result<Option<JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.minimum(column, db)))
    }

    /// Largest non-null value of `column`, or `None`.
    pub async fn maximum(
        &self,
        column: &str,
        db: &DatabaseConnection,
    ) -> Result<Option<JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        relation_extreme(self.pluck(column, db).await?, column, Ordering::Greater)
    }

    /// Blocking wrapper around [`Self::maximum`].
    pub fn maximum_sync(&self, column: &str) -> Result<Option<JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.maximum(column, db)))
    }

    /// Record count per group key (first group column), after the in-memory HAVING filter.
    ///
    /// # Errors
    /// `RecordError::Invalid` when the relation has no group column.
    pub async fn group_count(
        &self,
        db: &DatabaseConnection,
    ) -> Result<HashMap<JsonValue, JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let grouped = self.filtered_grouped_records(db).await?;
        Ok(grouped
            .into_iter()
            .map(|(key, records)| (key, json!(records.len() as i64)))
            .collect())
    }

    /// Sum of `column` per group key; integer when every summed value is integral.
    pub async fn group_sum(
        &self,
        column: &str,
        db: &DatabaseConnection,
    ) -> Result<HashMap<JsonValue, JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let grouped = self.filtered_grouped_records(db).await?;
        grouped
            .into_iter()
            .map(|(key, records)| Ok((key, aggregate_sum_value(&records, column)?)))
            .collect()
    }

    /// Mean of `column` per group key (`0.0` for groups with only nulls).
    pub async fn group_average(
        &self,
        column: &str,
        db: &DatabaseConnection,
    ) -> Result<HashMap<JsonValue, JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let grouped = self.filtered_grouped_records(db).await?;
        grouped
            .into_iter()
            .map(|(key, records)| Ok((key, json!(aggregate_average_value(&records, column)?))))
            .collect()
    }

    /// Smallest value of `column` per group key.
    ///
    /// # Errors
    /// `RecordError::Invalid` when a group holds only null values.
    pub async fn group_minimum(
        &self,
        column: &str,
        db: &DatabaseConnection,
    ) -> Result<HashMap<JsonValue, JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let grouped = self.filtered_grouped_records(db).await?;
        grouped
            .into_iter()
            .map(|(key, records)| {
                let value = aggregate_extreme_value(&records, column, Ordering::Less)?.ok_or_else(
                    || RecordError::Invalid(format!("group `{key}` has no values for `{column}`")),
                )?;
                Ok((key, value))
            })
            .collect()
    }

    /// Largest value of `column` per group key.
    ///
    /// # Errors
    /// `RecordError::Invalid` when a group holds only null values.
    pub async fn group_maximum(
        &self,
        column: &str,
        db: &DatabaseConnection,
    ) -> Result<HashMap<JsonValue, JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let grouped = self.filtered_grouped_records(db).await?;
        grouped
            .into_iter()
            .map(|(key, records)| {
                let value = aggregate_extreme_value(&records, column, Ordering::Greater)?
                    .ok_or_else(|| {
                        RecordError::Invalid(format!("group `{key}` has no values for `{column}`"))
                    })?;
                Ok((key, value))
            })
            .collect()
    }

    /// For every loaded record, the values of `columns` in the given order.
    ///
    /// # Errors
    /// `RecordError::Invalid` when `columns` is empty or names an unknown field.
    pub async fn pluck_columns(
        &self,
        columns: &[&str],
        db: &DatabaseConnection,
    ) -> Result<Vec<Vec<JsonValue>>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        if columns.is_empty() {
            return Err(RecordError::Invalid(
                "pluck_columns requires at least one column".to_owned(),
            ));
        }
        let records = self.load(db).await?;
        records
            .iter()
            .map(|record| {
                // Serialize each record once, not once per requested column.
                let serialized = serde_json::to_value(record)
                    .map_err(|error| RecordError::Invalid(error.to_string()))?;
                columns
                    .iter()
                    .map(|column| serialized_record_object_field(&serialized, column))
                    .collect()
            })
            .collect()
    }

    /// Blocking wrapper around [`Self::pluck_columns`].
    pub fn pluck_columns_sync(&self, columns: &[&str]) -> Result<Vec<Vec<JsonValue>>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.pluck_columns(columns, db)))
    }

    /// Values of `column` for every loaded record.
    pub async fn pluck(
        &self,
        column: &str,
        db: &DatabaseConnection,
    ) -> Result<Vec<JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        self.load(db)
            .await?
            .iter()
            .map(|record| serialized_record_field(record, column))
            .collect()
    }

    /// Blocking wrapper around [`Self::pluck`].
    pub fn pluck_sync(&self, column: &str) -> Result<Vec<JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.pluck(column, db)))
    }

    /// Value of `column` on the first record, or `None` when nothing matches.
    pub async fn pick(
        &self,
        column: &str,
        db: &DatabaseConnection,
    ) -> Result<Option<JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        self.first(db)
            .await?
            .as_ref()
            .map(|record| serialized_record_field(record, column))
            .transpose()
    }

    /// Blocking wrapper around [`Self::pick`].
    pub fn pick_sync(&self, column: &str) -> Result<Option<JsonValue>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.pick(column, db)))
    }

    /// Primary-key values of every loaded record, as `i64`.
    pub async fn ids(&self, db: &DatabaseConnection) -> Result<Vec<i64>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        self.pluck(T::primary_key_name(), db)
            .await?
            .into_iter()
            .map(|value| json_value_to_i64(&value, T::primary_key_name()))
            .collect()
    }

    /// Blocking wrapper around [`Self::ids`].
    pub fn ids_sync(&self) -> Result<Vec<i64>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        database::with_db(|db| runtime::block_on(self.ids(db)))
    }

    /// First `GROUP BY` column, or an error when the relation is not grouped.
    fn primary_group_column(&self) -> Result<&str, RecordError> {
        self.group_columns
            .first()
            .map(String::as_str)
            .ok_or_else(|| {
                RecordError::Invalid(
                    "grouped calculation requires at least one group column".to_owned(),
                )
            })
    }

    /// Loads the relation without GROUP BY / HAVING / projection and buckets the serialized
    /// records by their `group_column` value.
    async fn grouped_records_for_calculation(
        &self,
        group_column: &str,
        db: &DatabaseConnection,
    ) -> Result<HashMap<JsonValue, Vec<JsonValue>>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let mut relation = self.clone();
        relation.group_columns.clear();
        relation.having_conditions.clear();
        relation.select_columns.clear();
        let mut grouped: HashMap<JsonValue, Vec<JsonValue>> = HashMap::new();
        for record in relation.load(db).await? {
            let serialized = serde_json::to_value(&record)
                .map_err(|error| RecordError::Invalid(error.to_string()))?;
            let key = serialized_record_object_field(&serialized, group_column)?;
            grouped.entry(key).or_default().push(serialized);
        }
        Ok(grouped)
    }

    /// Grouped records with the HAVING conditions applied in memory.
    ///
    /// Validates that the relation is grouped before touching the database.
    async fn filtered_grouped_records(
        &self,
        db: &DatabaseConnection,
    ) -> Result<HashMap<JsonValue, Vec<JsonValue>>, RecordError>
    where
        T: serde::Serialize,
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let group_column = self.primary_group_column()?;
        let grouped = self.grouped_records_for_calculation(group_column, db).await?;
        apply_having_conditions(grouped, group_column, &self.having_conditions)
    }

    /// WHERE condition: OR over the non-empty groups, each an AND of its terms.
    ///
    /// With no non-empty group the condition matches every row.
    pub(crate) fn condition(&self) -> Result<Condition, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let active_groups = self
            .where_groups
            .iter()
            .filter(|group| !group.is_empty())
            .collect::<Vec<_>>();
        if active_groups.is_empty() {
            return Ok(Condition::all());
        }
        let mut disjunction = Condition::any();
        for group in active_groups {
            let mut conjunction = Condition::all();
            for (column, value) in &group.positive {
                conjunction = conjunction.add(build_filter::<T>(column, value, false)?);
            }
            for (column, value) in &group.negative {
                conjunction = conjunction.add(build_filter::<T>(column, value, true)?);
            }
            disjunction = disjunction.add(conjunction);
        }
        Ok(disjunction)
    }

    /// Translates the relation into a sea-orm `Select`.
    ///
    /// # Errors
    /// `RecordError::Invalid` when a selected, grouped, ordered or filtered column is unknown.
    fn build_select(&self) -> Result<Select<T::Entity>, RecordError>
    where
        <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let mut query = T::Entity::find().filter(self.condition()?);
        if !self.select_columns.is_empty() {
            // Resolve names first so a typo is reported (as it already is for order/group)
            // instead of silently projecting placeholders for every column.
            let selected = self
                .select_columns
                .iter()
                .map(|name| resolve_column::<T>(name).map(|column| column.as_str()))
                .collect::<Result<Vec<_>, _>>()?;
            query = query.select_only();
            for column in <T::Entity as EntityTrait>::Column::iter() {
                if selected.contains(&column.as_str()) {
                    query = query.column(column);
                } else if let Some(default_expr) = default_projection_expr(&column) {
                    // Unselected columns still need a value of the right kind so `load` can
                    // decode full models; project a placeholder literal under the column name.
                    query = query.expr_as(default_expr, column.as_str());
                } else {
                    query = query.column(column);
                }
            }
        }
        if self.is_distinct {
            query = query.distinct();
        }
        for column in &self.group_columns {
            query = query.group_by(resolve_column::<T>(column)?);
        }
        for condition in &self.having_conditions {
            // Raw SQL supplied by the caller; see `having`.
            query = query.having(Expr::cust(condition.clone()));
        }
        for (column, dir) in &self.order_by {
            query = query.order_by(resolve_column::<T>(column)?, sea_order(*dir));
        }
        if let Some(limit) = self.limit_val {
            query = query.limit(limit);
        }
        if let Some(offset) = self.offset_val {
            // Some backends (SQLite, MySQL) only accept OFFSET together with LIMIT.
            if self.limit_val.is_none() {
                query = query.limit(i64::MAX as u64);
            }
            query = query.offset(offset);
        }
        Ok(query)
    }
}
/// Comparison operator of a parsed `HAVING` predicate.
#[derive(Debug, Clone, Copy)]
enum HavingComparison {
    /// `=`
    Eq,
    /// `!=`
    NotEq,
    /// `>`
    Gt,
    /// `>=`
    Gte,
    /// `<`
    Lt,
    /// `<=`
    Lte,
}
/// Left-hand side of a parsed `HAVING` predicate, evaluated in memory for each group.
#[derive(Debug, Clone)]
enum HavingExpression {
    /// The grouped column itself, i.e. the group's key.
    GroupKey,
    /// `COUNT(*)`: number of records in the group.
    Count,
    /// `SUM(column)`.
    Sum(String),
    /// `AVG(column)` / `AVERAGE(column)`.
    Average(String),
    /// `MIN(column)` / `MINIMUM(column)`.
    Minimum(String),
    /// `MAX(column)` / `MAXIMUM(column)`.
    Maximum(String),
}
/// One parsed `HAVING` condition: `left comparison right`.
#[derive(Debug, Clone)]
struct HavingPredicate {
    /// Group key or aggregate computed for each group.
    left: HavingExpression,
    /// Operator applied between `left` and `right`.
    comparison: HavingComparison,
    /// Literal parsed from the right-hand side of the condition.
    right: JsonValue,
}
/// Reads field `column` from a record that has already been serialized to JSON.
///
/// # Errors
/// `RecordError::Invalid` when `record` is not a JSON object or lacks `column`.
fn serialized_record_object_field(
    record: &JsonValue,
    column: &str,
) -> Result<JsonValue, RecordError> {
    let Some(object) = record.as_object() else {
        return Err(RecordError::Invalid(
            "record must serialize to a JSON object".to_owned(),
        ));
    };
    match object.get(column) {
        Some(field) => Ok(field.clone()),
        None => Err(RecordError::Invalid(format!("unknown column: {column}"))),
    }
}
fn serialized_record_field<T: serde::Serialize>(
record: &T,
column: &str,
) -> Result<JsonValue, RecordError> {
let value =
serde_json::to_value(record).map_err(|error| RecordError::Invalid(error.to_string()))?;
serialized_record_object_field(&value, column)
}
/// Interprets a JSON value as a number: `null` → `None`, number → `Some(f64)`.
///
/// # Errors
/// `RecordError::Invalid` for any non-numeric, non-null value.
fn json_value_to_f64(value: &JsonValue, column: &str) -> Result<Option<f64>, RecordError> {
    let not_numeric =
        || RecordError::Invalid(format!("column `{column}` must contain numeric values"));
    if value.is_null() {
        return Ok(None);
    }
    let JsonValue::Number(number) = value else {
        return Err(not_numeric());
    };
    number.as_f64().map(Some).ok_or_else(not_numeric)
}
/// Interprets a JSON value as an `i64`.
///
/// # Errors
/// `RecordError::Invalid` when the value is not an integer or is an unsigned integer too
/// large for `i64`.
fn json_value_to_i64(value: &JsonValue, column: &str) -> Result<i64, RecordError> {
    let not_integer =
        || RecordError::Invalid(format!("column `{column}` must contain integer values"));
    let JsonValue::Number(number) = value else {
        return Err(not_integer());
    };
    if let Some(signed) = number.as_i64() {
        return Ok(signed);
    }
    match number.as_u64() {
        Some(unsigned) => i64::try_from(unsigned).map_err(|_| {
            RecordError::Invalid(format!("column `{column}` does not fit in i64"))
        }),
        None => Err(not_integer()),
    }
}
/// Extracts field `column` from every serialized record of a group, in order.
///
/// # Errors
/// The first extraction error (non-object record or missing field).
fn extracted_group_values(
    records: &[JsonValue],
    column: &str,
) -> Result<Vec<JsonValue>, RecordError> {
    let mut extracted = Vec::with_capacity(records.len());
    for record in records {
        extracted.push(serialized_record_object_field(record, column)?);
    }
    Ok(extracted)
}
fn json_value_is_integral(value: &JsonValue) -> bool {
matches!(value, JsonValue::Number(number) if number.as_i64().is_some() || number.as_u64().is_some())
}
/// Converts an aggregate result to JSON.
///
/// When `prefer_integer` is set and `value` is a whole number that fits in `i64`, an integer
/// is emitted; otherwise a float (non-finite floats become `null`, as `json!` does).
fn numeric_json_value(value: f64, prefer_integer: bool) -> JsonValue {
    // `i64::MAX as f64` rounds up to 2^63, which is NOT representable as i64 (the cast would
    // saturate to i64::MAX), so the upper bound must be exclusive. `i64::MIN as f64` is
    // exactly -2^63 and therefore fine as an inclusive lower bound.
    const I64_UPPER_EXCLUSIVE: f64 = 9_223_372_036_854_775_808.0;
    if prefer_integer
        && value.fract() == 0.0
        && value >= i64::MIN as f64
        && value < I64_UPPER_EXCLUSIVE
    {
        json!(value as i64)
    } else {
        json!(value)
    }
}
/// Sums `column` across a group's records, skipping nulls.
///
/// The result is an integer when every summed value was integral, otherwise a float.
fn aggregate_sum_value(records: &[JsonValue], column: &str) -> Result<JsonValue, RecordError> {
    let mut sum = 0.0;
    let mut integral_only = true;
    for raw in extracted_group_values(records, column)? {
        let Some(number) = json_value_to_f64(&raw, column)? else {
            continue;
        };
        sum += number;
        integral_only = integral_only && json_value_is_integral(&raw);
    }
    Ok(numeric_json_value(sum, integral_only))
}
/// Mean of the non-null values of `column` across a group's records (`0.0` if none).
fn aggregate_average_value(records: &[JsonValue], column: &str) -> Result<f64, RecordError> {
    let values = extracted_group_values(records, column)?;
    let (sum, numeric_count) = values.iter().try_fold(
        (0.0_f64, 0_u64),
        |(sum, numeric_count), value| -> Result<(f64, u64), RecordError> {
            Ok(match json_value_to_f64(value, column)? {
                Some(number) => (sum + number, numeric_count + 1),
                None => (sum, numeric_count),
            })
        },
    )?;
    // A group with only null values averages to zero rather than NaN.
    Ok(if numeric_count == 0 {
        0.0
    } else {
        sum / numeric_count as f64
    })
}
/// Minimum (`Ordering::Less`) or maximum (`Ordering::Greater`) of `column` in a group.
fn aggregate_extreme_value(
    records: &[JsonValue],
    column: &str,
    preferred: Ordering,
) -> Result<Option<JsonValue>, RecordError> {
    let values = extracted_group_values(records, column)?;
    relation_extreme(values, column, preferred)
}
/// Keeps only the groups that satisfy every HAVING condition, evaluated in memory.
///
/// # Errors
/// Any parse error of a condition, or any evaluation/comparison error.
fn apply_having_conditions(
    grouped: HashMap<JsonValue, Vec<JsonValue>>,
    group_column: &str,
    conditions: &[String],
) -> Result<HashMap<JsonValue, Vec<JsonValue>>, RecordError> {
    if conditions.is_empty() {
        return Ok(grouped);
    }
    // Parse every predicate up front so a malformed clause fails before any evaluation.
    let mut predicates = Vec::with_capacity(conditions.len());
    for condition in conditions {
        predicates.push(parse_having_predicate(condition, group_column)?);
    }
    let mut survivors = HashMap::with_capacity(grouped.len());
    'groups: for (key, records) in grouped {
        for predicate in &predicates {
            let observed = evaluate_having_expression(&predicate.left, &key, &records)?;
            if !compare_having_values(&observed, &predicate.right, predicate.comparison)? {
                continue 'groups;
            }
        }
        survivors.insert(key, records);
    }
    Ok(survivors)
}
/// Parses one `HAVING` fragment of the form `<expression> <operator> <literal>`.
///
/// Supported operators: `>=`, `<=`, `!=`, `<>`, `==`, `=`, `>`, `<`. The split happens at the
/// first operator character, so a literal that itself contains `=`, `<` or `>` (for example
/// `category = 'a>=b'`) stays intact.
///
/// # Errors
/// `RecordError::Invalid` when no operator is found or either side fails to parse.
fn parse_having_predicate(
    condition: &str,
    group_column: &str,
) -> Result<HavingPredicate, RecordError> {
    let unsupported =
        || RecordError::Invalid(format!("unsupported HAVING clause: {condition}"));
    let trimmed = condition.trim();
    // Expressions (`COUNT(*)`, `SUM(x)`, a column name) never contain these characters, so
    // the first one marks the start of the operator.
    let operator_start = trimmed
        .find(|character: char| matches!(character, '<' | '>' | '=' | '!'))
        .ok_or_else(unsupported)?;
    let (left, tail) = trimmed.split_at(operator_start);
    // Two-character operators are listed before their one-character prefixes.
    let (operator, comparison) = [
        (">=", HavingComparison::Gte),
        ("<=", HavingComparison::Lte),
        ("!=", HavingComparison::NotEq),
        ("<>", HavingComparison::NotEq),
        ("==", HavingComparison::Eq),
        ("=", HavingComparison::Eq),
        (">", HavingComparison::Gt),
        ("<", HavingComparison::Lt),
    ]
    .into_iter()
    .find(|(operator, _)| tail.starts_with(*operator))
    .ok_or_else(unsupported)?;
    let right = &tail[operator.len()..];
    Ok(HavingPredicate {
        left: parse_having_expression(left.trim(), group_column)?,
        comparison,
        right: parse_having_value(right.trim())?,
    })
}
/// Parses the left side of a HAVING predicate.
///
/// Accepts the group column (bare or table-qualified, case-insensitive) or exactly
/// `FUNCTION(argument)` with `COUNT(*)`/`COUNT(1)`, `SUM`, `AVG`/`AVERAGE`,
/// `MIN`/`MINIMUM` or `MAX`/`MAXIMUM`.
///
/// # Errors
/// `RecordError::Invalid` for anything else, including malformed parentheses or trailing text
/// after the closing `)` (which would otherwise be silently ignored, e.g. `SUM(x) * 2`).
fn parse_having_expression(
    expression: &str,
    group_column: &str,
) -> Result<HavingExpression, RecordError> {
    let unsupported =
        || RecordError::Invalid(format!("unsupported HAVING expression: {expression}"));
    let normalized = expression.trim();
    if normalized.eq_ignore_ascii_case(group_column)
        || normalized
            .rsplit('.')
            .next()
            .is_some_and(|column| column.eq_ignore_ascii_case(group_column))
    {
        return Ok(HavingExpression::GroupKey);
    }
    // `split_once` + `strip_suffix` cannot produce an inverted range, unlike indexing with
    // `find('(')` / `rfind(')')`, which panicked on inputs such as `)x(`.
    let (function, rest) = normalized.split_once('(').ok_or_else(unsupported)?;
    let argument = rest.strip_suffix(')').ok_or_else(unsupported)?.trim();
    match function.trim().to_ascii_uppercase().as_str() {
        "COUNT" if argument == "*" || argument == "1" => Ok(HavingExpression::Count),
        "SUM" => Ok(HavingExpression::Sum(argument.to_owned())),
        "AVG" | "AVERAGE" => Ok(HavingExpression::Average(argument.to_owned())),
        "MIN" | "MINIMUM" => Ok(HavingExpression::Minimum(argument.to_owned())),
        "MAX" | "MAXIMUM" => Ok(HavingExpression::Maximum(argument.to_owned())),
        _ => Err(unsupported()),
    }
}
/// Parses the right side of a HAVING predicate.
///
/// Accepts a single- or double-quoted string, `true`/`false` (case-insensitive), an `i64`,
/// a `u64`, or a finite float.
///
/// # Errors
/// `RecordError::Invalid` for anything else. A lone quote character is rejected here;
/// the previous byte slicing panicked on it.
fn parse_having_value(value: &str) -> Result<JsonValue, RecordError> {
    for quote in ['\'', '"'] {
        if let Some(inner) = value
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
        {
            return Ok(json!(inner));
        }
    }
    if value.eq_ignore_ascii_case("true") {
        return Ok(json!(true));
    }
    if value.eq_ignore_ascii_case("false") {
        return Ok(json!(false));
    }
    if let Ok(number) = value.parse::<i64>() {
        return Ok(json!(number));
    }
    if let Ok(number) = value.parse::<u64>() {
        return Ok(json!(number));
    }
    // `inf`/`NaN` parse as f64 but cannot be represented in JSON (they would become null).
    if let Some(number) = value.parse::<f64>().ok().filter(|number| number.is_finite()) {
        return Ok(json!(number));
    }
    Err(RecordError::Invalid(format!(
        "unsupported HAVING value: {value}"
    )))
}
fn evaluate_having_expression(
expression: &HavingExpression,
key: &JsonValue,
records: &[JsonValue],
) -> Result<JsonValue, RecordError> {
match expression {
HavingExpression::GroupKey => Ok(key.clone()),
HavingExpression::Count => Ok(json!(records.len() as i64)),
HavingExpression::Sum(column) => aggregate_sum_value(records, column),
HavingExpression::Average(column) => Ok(json!(aggregate_average_value(records, column)?)),
HavingExpression::Minimum(column) => {
aggregate_extreme_value(records, column, Ordering::Less)?
.ok_or_else(|| RecordError::Invalid(format!("group has no values for `{column}`")))
}
HavingExpression::Maximum(column) => {
aggregate_extreme_value(records, column, Ordering::Greater)?
.ok_or_else(|| RecordError::Invalid(format!("group has no values for `{column}`")))
}
}
}
fn compare_having_values(
left: &JsonValue,
right: &JsonValue,
comparison: HavingComparison,
) -> Result<bool, RecordError> {
match comparison {
HavingComparison::Eq => Ok(left == right),
HavingComparison::NotEq => Ok(left != right),
HavingComparison::Gt => {
Ok(compare_json_values(left, right, "HAVING")? == Ordering::Greater)
}
HavingComparison::Gte => Ok(matches!(
compare_json_values(left, right, "HAVING")?,
Ordering::Greater | Ordering::Equal
)),
HavingComparison::Lt => Ok(compare_json_values(left, right, "HAVING")? == Ordering::Less),
HavingComparison::Lte => Ok(matches!(
compare_json_values(left, right, "HAVING")?,
Ordering::Less | Ordering::Equal
)),
}
}
/// Placeholder literal projected for a column the caller did not select.
///
/// Returns `''` for textual/custom/enum columns, `0` for integer-like columns, `0.0` for
/// floating/decimal/money columns and `false` for booleans. Returns `None` for every other
/// type; the caller then selects the real column.
fn default_projection_expr<C>(column: &C) -> Option<SimpleExpr>
where
    C: ColumnTrait,
{
    // Keep the definition alive: `get_column_type` borrows from it.
    let definition = column.def();
    let placeholder = match definition.get_column_type() {
        ColumnType::Boolean => Expr::val(false),
        ColumnType::Float | ColumnType::Double | ColumnType::Decimal(_) | ColumnType::Money(_) => {
            Expr::val(0.0)
        }
        ColumnType::TinyInteger
        | ColumnType::SmallInteger
        | ColumnType::Integer
        | ColumnType::BigInteger
        | ColumnType::TinyUnsigned
        | ColumnType::SmallUnsigned
        | ColumnType::Unsigned
        | ColumnType::BigUnsigned
        | ColumnType::Year => Expr::val(0),
        ColumnType::Char(_)
        | ColumnType::String(_)
        | ColumnType::Text
        | ColumnType::Custom(_)
        | ColumnType::Enum { .. } => Expr::val(""),
        _ => return None,
    };
    Some(placeholder)
}
/// Maps a column name to `T`'s sea-orm column via its `FromStr` implementation.
///
/// # Errors
/// `RecordError::Invalid("unknown column: …")` when the name does not parse.
pub(crate) fn resolve_column<T: Record>(
    name: &str,
) -> Result<<T::Entity as EntityTrait>::Column, RecordError>
where
    <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
{
    match <<T::Entity as EntityTrait>::Column as FromStr>::from_str(name) {
        Ok(column) => Ok(column),
        Err(_) => Err(RecordError::Invalid(format!("unknown column: {name}"))),
    }
}
pub(crate) fn json_to_sea_value(value: &JsonValue) -> Result<Value, RecordError> {
match value {
JsonValue::Null => Err(RecordError::Invalid(
"null values are not supported in this context".to_owned(),
)),
JsonValue::Bool(flag) => Ok((*flag).into()),
JsonValue::Number(number) => {
if let Some(value) = number.as_i64() {
Ok(value.into())
} else if let Some(value) = number.as_u64() {
Ok(value.into())
} else if let Some(value) = number.as_f64() {
Ok(value.into())
} else {
Err(RecordError::Invalid(format!(
"unsupported numeric value: {number}"
)))
}
}
JsonValue::String(text) => Ok(text.clone().into()),
JsonValue::Array(_) | JsonValue::Object(_) => Err(RecordError::Invalid(
"only scalar JSON values are supported in query conditions".to_owned(),
)),
}
}
/// Picks the minimum (`Ordering::Less`) or maximum (`Ordering::Greater`) of the non-null values.
///
/// Returns `None` when every value is null or the input is empty. On ties the earliest value
/// wins.
fn relation_extreme(
    values: Vec<JsonValue>,
    column: &str,
    preferred: Ordering,
) -> Result<Option<JsonValue>, RecordError> {
    let mut present = values.into_iter().filter(|value| !value.is_null());
    match present.next() {
        None => Ok(None),
        Some(seed) => present
            .try_fold(seed, |best, candidate| {
                let wins = compare_json_values(&candidate, &best, column)? == preferred;
                Ok::<JsonValue, RecordError>(if wins { candidate } else { best })
            })
            .map(Some),
    }
}
/// Orders two scalar JSON values of the same kind (number/string/bool/null).
///
/// # Errors
/// `RecordError::Invalid` for mixed kinds, arrays/objects, or numbers without an ordering.
fn compare_json_values(
    left: &JsonValue,
    right: &JsonValue,
    column: &str,
) -> Result<Ordering, RecordError> {
    match (left, right) {
        (JsonValue::Number(lhs), JsonValue::Number(rhs)) => {
            // Integers are compared exactly. Going through f64 makes distinct values above
            // 2^53 (e.g. large ids) compare as equal.
            if let (Some(a), Some(b)) = (lhs.as_i64(), rhs.as_i64()) {
                return Ok(a.cmp(&b));
            }
            if let (Some(a), Some(b)) = (lhs.as_u64(), rhs.as_u64()) {
                return Ok(a.cmp(&b));
            }
            lhs.as_f64()
                .zip(rhs.as_f64())
                .and_then(|(a, b)| a.partial_cmp(&b))
                .ok_or_else(|| {
                    RecordError::Invalid(format!(
                        "column `{column}` must contain comparable numeric values"
                    ))
                })
        }
        (JsonValue::String(lhs), JsonValue::String(rhs)) => Ok(lhs.cmp(rhs)),
        (JsonValue::Bool(lhs), JsonValue::Bool(rhs)) => Ok(lhs.cmp(rhs)),
        (JsonValue::Null, JsonValue::Null) => Ok(Ordering::Equal),
        _ => Err(RecordError::Invalid(format!(
            "column `{column}` must contain comparable scalar values"
        ))),
    }
}
fn sea_order(dir: OrderDirection) -> Order {
match dir {
OrderDirection::Asc => Order::Asc,
OrderDirection::Desc => Order::Desc,
}
}
/// Returns the keyword that is prepended to a query to obtain its plan on
/// `backend`.
///
/// SQLite uses `EXPLAIN QUERY PLAN`. Every other backend gets plain `EXPLAIN`.
fn explain_prefix(backend: DbBackend) -> &'static str {
    if matches!(backend, DbBackend::Sqlite) {
        "EXPLAIN QUERY PLAN"
    } else {
        "EXPLAIN"
    }
}
/// Builds a single-column predicate on `column_name` for entity `T`.
///
/// A JSON `null` becomes `IS NULL`, or `IS NOT NULL` when `negated` is set.
/// Any other value is converted with [`json_to_sea_value`] and compared for
/// equality, or inequality when `negated` is set.
///
/// # Errors
///
/// - Fails when `resolve_column` rejects `column_name`. This is checked first.
/// - Fails when `json_to_sea_value` cannot convert a non-null `value`.
fn build_filter<T: Record>(
    column_name: &str,
    value: &JsonValue,
    negated: bool,
) -> Result<SimpleExpr, RecordError>
where
    <T::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
{
    let target = Expr::col(resolve_column::<T>(column_name)?);
    if value.is_null() {
        return Ok(if negated {
            target.is_not_null()
        } else {
            target.is_null()
        });
    }
    let bound = json_to_sea_value(value)?;
    Ok(if negated {
        target.ne(bound)
    } else {
        target.eq(bound)
    })
}
// Tests for `Relation`: query building, execution through the `*_sync`
// executors, aggregates, grouping, and batching. Database-backed tests run
// against a fresh in-memory SQLite database holding the `test_user` table
// (see `run_relation_test`). The last few tests only inspect builder state.
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use rustrails_support::{database, runtime};
use sea_orm::{ActiveModelTrait, ActiveValue::Set, ConnectionTrait, Schema};
use serde_json::{Value, json};
use super::Relation;
use crate::{
OrderDirection, RecordState,
base::test_support::{TestUser, seed_users, test_user},
};
/// Runs `test` on a dedicated thread after setting up the database.
///
/// Setup steps, in order:
/// 1. Initialise the runtime.
/// 2. Connect to `sqlite::memory:`.
/// 3. Create the `test_user` table from its entity.
/// 4. When `seed` is true, insert fixtures via `seed_users`.
///
/// The thread is joined and `unwrap` re-raises any panic from `test`, so a
/// failed assertion still fails the calling `#[test]`.
// NOTE(review): the per-test thread presumably isolates runtime/connection
// state managed by `runtime`/`database`. Confirm against those modules.
fn run_relation_test(seed: bool, test: impl FnOnce() + Send + 'static) {
std::thread::spawn(move || {
// Named `_rt` rather than `_` so the value is not dropped immediately and
// stays alive until the thread finishes.
let _rt = runtime::init_runtime();
database::establish("sqlite::memory:")
.expect("sqlite in-memory connection should succeed");
runtime::block_on(async {
let db = database::db();
let schema = Schema::new(db.get_database_backend());
db.execute(&schema.create_table_from_entity(test_user::Entity))
.await
.expect("test_users table should be created");
if seed {
seed_users(&db).await;
}
});
test();
})
.join()
.unwrap();
}
/// Runs `test` against a table pre-populated by `seed_users`.
fn run_seeded_relation_test(test: impl FnOnce() + Send + 'static) {
run_relation_test(true, test);
}
/// Runs `test` against an empty `test_user` table.
fn run_empty_relation_test(test: impl FnOnce() + Send + 'static) {
run_relation_test(false, test);
}
// The wrappers below pin the relation type to `Relation<TestUser>` and call
// the matching `*_sync` executor, so each test reads as a single call.
fn load_relation(relation: Relation<TestUser>) -> Result<Vec<TestUser>, crate::RecordError> {
relation.load_sync()
}
fn first_relation(
relation: Relation<TestUser>,
) -> Result<Option<TestUser>, crate::RecordError> {
relation.first_sync()
}
fn last_relation(relation: Relation<TestUser>) -> Result<Option<TestUser>, crate::RecordError> {
relation.last_sync()
}
fn count_relation(relation: Relation<TestUser>) -> Result<u64, crate::RecordError> {
relation.count_sync()
}
fn exists_relation(relation: Relation<TestUser>) -> Result<bool, crate::RecordError> {
relation.exists_sync()
}
fn explain_relation(relation: Relation<TestUser>) -> Result<String, crate::RecordError> {
relation.explain_sync()
}
fn sum_relation(relation: Relation<TestUser>, column: &str) -> Result<f64, crate::RecordError> {
relation.sum_sync(column)
}
fn average_relation(
relation: Relation<TestUser>,
column: &str,
) -> Result<f64, crate::RecordError> {
relation.average_sync(column)
}
fn minimum_relation(
relation: Relation<TestUser>,
column: &str,
) -> Result<Option<Value>, crate::RecordError> {
relation.minimum_sync(column)
}
fn maximum_relation(
relation: Relation<TestUser>,
column: &str,
) -> Result<Option<Value>, crate::RecordError> {
relation.maximum_sync(column)
}
fn pluck_relation(
relation: Relation<TestUser>,
column: &str,
) -> Result<Vec<Value>, crate::RecordError> {
relation.pluck_sync(column)
}
fn pick_relation(
relation: Relation<TestUser>,
column: &str,
) -> Result<Option<Value>, crate::RecordError> {
relation.pick_sync(column)
}
fn relation_ids(relation: Relation<TestUser>) -> Result<Vec<i64>, crate::RecordError> {
relation.ids_sync()
}
fn sole_relation(relation: Relation<TestUser>) -> Result<TestUser, crate::RecordError> {
relation.sole_sync()
}
fn pluck_columns_relation(
relation: Relation<TestUser>,
columns: &[&str],
) -> Result<Vec<Vec<Value>>, crate::RecordError> {
relation.pluck_columns_sync(columns)
}
// Grouped aggregates are called through their async methods here, blocking on
// each one with the shared connection from `database::with_db`.
fn group_count_relation(
relation: Relation<TestUser>,
) -> Result<HashMap<Value, Value>, crate::RecordError> {
database::with_db(|db| runtime::block_on(relation.group_count(db)))
}
fn group_sum_relation(
relation: Relation<TestUser>,
column: &str,
) -> Result<HashMap<Value, Value>, crate::RecordError> {
database::with_db(|db| runtime::block_on(relation.group_sum(column, db)))
}
fn group_average_relation(
relation: Relation<TestUser>,
column: &str,
) -> Result<HashMap<Value, Value>, crate::RecordError> {
database::with_db(|db| runtime::block_on(relation.group_average(column, db)))
}
fn group_minimum_relation(
relation: Relation<TestUser>,
column: &str,
) -> Result<HashMap<Value, Value>, crate::RecordError> {
database::with_db(|db| runtime::block_on(relation.group_minimum(column, db)))
}
fn group_maximum_relation(
relation: Relation<TestUser>,
column: &str,
) -> Result<HashMap<Value, Value>, crate::RecordError> {
database::with_db(|db| runtime::block_on(relation.group_maximum(column, db)))
}
/// Inserts one extra `test_user` row with the given name and email.
///
/// `id` is left to the database. After seeding, the tests below expect this
/// row to receive id 4.
fn insert_user(name: &str, email: &str) {
database::with_db(|db| {
runtime::block_on(async {
test_user::ActiveModel {
name: Set(name.to_owned()),
email: Set(email.to_owned()),
..Default::default()
}
.insert(db)
.await
.expect("fixture insert should succeed");
});
});
}
/// Loads `relation` and returns just the user names, in result order.
fn relation_names(relation: Relation<TestUser>) -> Vec<String> {
load_relation(relation)
.expect("relation should load")
.into_iter()
.map(|user| user.name)
.collect()
}
/// Loads `relation` and returns `(name, email)` pairs, in result order.
fn relation_name_email_pairs(relation: Relation<TestUser>) -> Vec<(String, String)> {
load_relation(relation)
.expect("relation should load")
.into_iter()
.map(|user| (user.name, user.email))
.collect()
}
/// Collects the names yielded one at a time by `find_each_sync`.
fn find_each_names(
relation: Relation<TestUser>,
batch_size: u64,
) -> Result<Vec<String>, crate::RecordError> {
let mut names = Vec::new();
relation.find_each_sync(batch_size, |user| names.push(user.name))?;
Ok(names)
}
/// Collects the names from each batch yielded by `find_in_batches_sync`,
/// keeping one inner `Vec` per batch.
fn find_batch_names(
relation: Relation<TestUser>,
batch_size: u64,
) -> Result<Vec<Vec<String>>, crate::RecordError> {
let mut batches = Vec::new();
relation.find_in_batches_sync(batch_size, |users| {
batches.push(users.into_iter().map(|user| user.name).collect());
})?;
Ok(batches)
}
// Seeded tests below expect `seed_users` to insert these rows:
// - Alice (id 1, alice@example.com)
// - Bob (id 2, bob@example.com)
// - Carol (id 3, carol@example.com)
#[test]
fn relation_load_returns_all_rows_when_scope_empty() {
run_seeded_relation_test(|| {
let users = load_relation(Relation::<TestUser>::new()).expect("relation should load");
assert_eq!(users.len(), 3);
assert_eq!(users[0].name, "Alice");
assert_eq!(users[2].name, "Carol");
});
}
#[test]
fn relation_load_returns_empty_vec_when_table_empty() {
run_empty_relation_test(|| {
let users = load_relation(Relation::<TestUser>::new()).expect("relation should load");
assert!(users.is_empty());
});
}
#[test]
fn relation_where_matches_single_name_condition() {
run_seeded_relation_test(|| {
let users = load_relation(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Bob"))])),
)
.expect("relation should load");
assert_eq!(users.len(), 1);
assert_eq!(users[0].email, "bob@example.com");
});
}
#[test]
fn relation_where_matches_single_email_condition() {
run_seeded_relation_test(|| {
let users = load_relation(Relation::<TestUser>::new().r#where(HashMap::from([(
"email".to_owned(),
json!("carol@example.com"),
)])))
.expect("relation should load");
assert_eq!(users.len(), 1);
assert_eq!(users[0].name, "Carol");
});
}
#[test]
fn relation_where_with_multiple_conditions_requires_all_matches() {
run_seeded_relation_test(|| {
let users = load_relation(Relation::<TestUser>::new().r#where(HashMap::from([
("name".to_owned(), json!("Bob")),
("email".to_owned(), json!("bob@example.com")),
])))
.expect("relation should load");
assert_eq!(users.len(), 1);
assert_eq!(users[0].id, Some(2));
});
}
#[test]
fn relation_where_with_multiple_conditions_returns_no_partial_matches() {
run_seeded_relation_test(|| {
let users = load_relation(Relation::<TestUser>::new().r#where(HashMap::from([
("name".to_owned(), json!("Bob")),
("email".to_owned(), json!("alice@example.com")),
])))
.expect("relation should load");
assert!(users.is_empty());
});
}
#[test]
fn relation_order_ascending_returns_expected_sequence() {
run_seeded_relation_test(|| {
assert_eq!(
relation_names(Relation::<TestUser>::new().order("id", OrderDirection::Asc)),
vec!["Alice", "Bob", "Carol"],
);
});
}
#[test]
fn relation_order_descending_returns_expected_sequence() {
run_seeded_relation_test(|| {
assert_eq!(
relation_names(Relation::<TestUser>::new().order("id", OrderDirection::Desc)),
vec!["Carol", "Bob", "Alice"],
);
});
}
#[test]
fn relation_reorder_replaces_existing_sort_clauses() {
run_seeded_relation_test(|| {
assert_eq!(
relation_name_email_pairs(
Relation::<TestUser>::new()
.order("name", OrderDirection::Asc)
.reorder("id", OrderDirection::Desc),
),
vec![
("Carol".to_owned(), "carol@example.com".to_owned()),
("Bob".to_owned(), "bob@example.com".to_owned()),
("Alice".to_owned(), "alice@example.com".to_owned()),
],
);
});
}
#[test]
fn relation_where_order_limit_offset_can_be_chained() {
run_seeded_relation_test(|| {
// Two Bobs now exist: ids 4 and 2 in descending order. Skipping one row
// leaves the original Bob.
insert_user("Bob", "bobby@example.com");
let users = load_relation(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Bob"))]))
.order("id", OrderDirection::Desc)
.offset(1)
.limit(1),
)
.expect("relation should load");
assert_eq!(users.len(), 1);
assert_eq!(users[0].email, "bob@example.com");
});
}
#[test]
fn relation_where_order_distinct_limit_can_be_chained() {
run_seeded_relation_test(|| {
// Both Bob rows differ in id and email, so DISTINCT keeps both of them.
insert_user("Bob", "bobby@example.com");
let users = load_relation(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Bob"))]))
.order("id", OrderDirection::Desc)
.distinct()
.limit(5),
)
.expect("relation should load");
assert_eq!(users.len(), 2);
assert_eq!(users[0].email, "bobby@example.com");
assert_eq!(users[1].email, "bob@example.com");
});
}
#[test]
fn relation_find_in_batches_yields_two_batches_for_three_rows() {
run_seeded_relation_test(|| {
let batches = find_batch_names(
Relation::<TestUser>::new().order("id", OrderDirection::Asc),
2,
)
.expect("batch query should succeed");
assert_eq!(
batches,
vec![
vec!["Alice".to_owned(), "Bob".to_owned()],
vec!["Carol".to_owned()],
]
);
});
}
#[test]
fn relation_find_each_batch_processes_all_rows() {
run_seeded_relation_test(|| {
let names = find_each_names(
Relation::<TestUser>::new().order("id", OrderDirection::Asc),
2,
)
.expect("each query should succeed");
assert_eq!(names, vec!["Alice", "Bob", "Carol"]);
});
}
#[test]
fn relation_find_in_batches_yields_no_batches_for_empty_table() {
run_empty_relation_test(|| {
let batches = find_batch_names(
Relation::<TestUser>::new().order("id", OrderDirection::Asc),
2,
)
.expect("batch query should succeed");
assert!(batches.is_empty());
});
}
#[test]
fn relation_find_in_batches_respects_existing_limit_and_offset_scope() {
run_seeded_relation_test(|| {
// Offset 1 plus limit 2 scopes the relation to Bob and Carol. Batch size 1
// then yields each of them in its own batch.
let batches = find_batch_names(
Relation::<TestUser>::new()
.order("id", OrderDirection::Asc)
.offset(1)
.limit(2),
1,
)
.expect("batch query should succeed");
assert_eq!(
batches,
vec![vec!["Bob".to_owned()], vec!["Carol".to_owned()]],
);
});
}
#[test]
fn relation_find_in_batches_rejects_zero_batch_size() {
run_seeded_relation_test(|| {
let error = find_batch_names(
Relation::<TestUser>::new().order("id", OrderDirection::Asc),
0,
)
.expect_err("zero batch size should fail");
assert!(matches!(error, crate::RecordError::Invalid(_)));
});
}
#[test]
fn relation_count_returns_total_for_unscoped_relation() {
run_seeded_relation_test(|| {
assert_eq!(
count_relation(Relation::<TestUser>::new()).expect("count should succeed"),
3
);
});
}
#[test]
fn relation_count_returns_matches_for_scoped_relation() {
run_seeded_relation_test(|| {
assert_eq!(
count_relation(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Alice"))])),
)
.expect("count should succeed"),
1,
);
});
}
#[test]
fn relation_count_returns_zero_for_empty_scope() {
run_seeded_relation_test(|| {
assert_eq!(
count_relation(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Nobody"))])),
)
.expect("count should succeed"),
0,
);
});
}
#[test]
fn relation_count_respects_limit_and_offset_scope() {
run_seeded_relation_test(|| {
assert_eq!(
count_relation(
Relation::<TestUser>::new()
.order("id", OrderDirection::Asc)
.offset(1)
.limit(1),
)
.expect("count should succeed"),
1,
);
});
}
#[test]
fn relation_grouped_count_counts_groups() {
run_seeded_relation_test(|| {
// Four rows but three distinct names: a grouped count returns the number
// of groups, not the number of rows.
insert_user("Bob", "bobby@example.com");
assert_eq!(
count_relation(Relation::<TestUser>::new().group("name"))
.expect("count should succeed"),
3,
);
});
}
#[test]
fn relation_first_returns_first_row_for_ordered_relation() {
run_seeded_relation_test(|| {
let user =
first_relation(Relation::<TestUser>::new().order("id", OrderDirection::Desc))
.expect("query should succeed")
.expect("row should exist");
assert_eq!(user.name, "Carol");
});
}
#[test]
fn relation_first_respects_offset_and_limit() {
run_seeded_relation_test(|| {
let user = first_relation(
Relation::<TestUser>::new()
.order("id", OrderDirection::Asc)
.offset(1)
.limit(1),
)
.expect("query should succeed")
.expect("row should exist");
assert_eq!(user.name, "Bob");
});
}
#[test]
fn relation_first_returns_none_for_empty_scope() {
run_seeded_relation_test(|| {
let user = first_relation(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Nobody"))])),
)
.expect("query should succeed");
assert!(user.is_none());
});
}
#[test]
fn relation_last_returns_last_row_for_ordered_relation() {
run_seeded_relation_test(|| {
let user = last_relation(Relation::<TestUser>::new().order("id", OrderDirection::Asc))
.expect("query should succeed")
.expect("row should exist");
assert_eq!(user.name, "Carol");
});
}
#[test]
fn relation_last_respects_offset_and_limit() {
run_seeded_relation_test(|| {
// The scoped window is [Bob, Carol], so `last` must be Carol.
let user = last_relation(
Relation::<TestUser>::new()
.order("id", OrderDirection::Asc)
.offset(1)
.limit(2),
)
.expect("query should succeed")
.expect("row should exist");
assert_eq!(user.name, "Carol");
});
}
#[test]
fn relation_last_returns_none_for_empty_scope() {
run_seeded_relation_test(|| {
let user = last_relation(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Nobody"))])),
)
.expect("query should succeed");
assert!(user.is_none());
});
}
#[test]
fn relation_exists_returns_true_for_matching_scope() {
run_seeded_relation_test(|| {
assert!(
exists_relation(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Bob"))])),
)
.expect("exists should succeed"),
);
});
}
#[test]
fn relation_exists_returns_false_for_empty_scope() {
run_seeded_relation_test(|| {
assert!(
!exists_relation(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Nobody"))])),
)
.expect("exists should succeed"),
);
});
}
// The explain tests only check that some plan text comes back; the exact
// plan output is backend-specific and is not asserted.
#[test]
fn relation_explain_returns_non_empty_plan_for_simple_query() {
run_seeded_relation_test(|| {
let plan =
explain_relation(Relation::<TestUser>::new()).expect("explain should succeed");
assert!(!plan.trim().is_empty());
});
}
#[test]
fn relation_explain_returns_non_empty_plan_for_filtered_query() {
run_seeded_relation_test(|| {
let plan = explain_relation(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Bob"))])),
)
.expect("explain should succeed");
assert!(!plan.trim().is_empty());
});
}
#[test]
fn relation_aggregation_sum_returns_total_for_id_column() {
run_seeded_relation_test(|| {
assert_eq!(
sum_relation(Relation::<TestUser>::new(), "id").expect("sum should succeed"),
6.0,
);
});
}
#[test]
fn relation_aggregation_average_returns_mean_for_id_column() {
run_seeded_relation_test(|| {
assert_eq!(
average_relation(Relation::<TestUser>::new(), "id")
.expect("average should succeed"),
2.0,
);
});
}
#[test]
fn relation_aggregation_minimum_returns_first_id() {
run_seeded_relation_test(|| {
assert_eq!(
minimum_relation(
Relation::<TestUser>::new().order("id", OrderDirection::Asc),
"id"
)
.expect("minimum should succeed"),
Some(json!(1)),
);
});
}
#[test]
fn relation_aggregation_maximum_returns_last_id() {
run_seeded_relation_test(|| {
assert_eq!(
maximum_relation(
Relation::<TestUser>::new().order("id", OrderDirection::Asc),
"id"
)
.expect("maximum should succeed"),
Some(json!(3)),
);
});
}
#[test]
fn relation_aggregation_pluck_returns_names() {
run_seeded_relation_test(|| {
assert_eq!(
pluck_relation(
Relation::<TestUser>::new().order("id", OrderDirection::Asc),
"name"
)
.expect("pluck should succeed"),
vec![json!("Alice"), json!("Bob"), json!("Carol")],
);
});
}
#[test]
fn relation_aggregation_pick_returns_first_name() {
run_seeded_relation_test(|| {
assert_eq!(
pick_relation(
Relation::<TestUser>::new().order("id", OrderDirection::Asc),
"name"
)
.expect("pick should succeed"),
Some(json!("Alice")),
);
});
}
#[test]
fn relation_aggregation_ids_returns_primary_keys() {
run_seeded_relation_test(|| {
assert_eq!(
relation_ids(Relation::<TestUser>::new().order("id", OrderDirection::Asc))
.expect("ids should succeed"),
vec![1, 2, 3],
);
});
}
#[test]
fn relation_sole_returns_matching_row_for_single_result_scope() {
run_seeded_relation_test(|| {
let user = sole_relation(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Alice"))])),
)
.expect("sole should succeed");
assert_eq!(user.email, "alice@example.com");
});
}
#[test]
fn relation_sole_returns_not_found_for_empty_scope() {
run_seeded_relation_test(|| {
let error = sole_relation(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Nobody"))])),
)
.expect_err("empty scopes should fail");
assert!(matches!(error, crate::RecordError::NotFound));
});
}
#[test]
fn relation_sole_returns_exceeded_for_multiple_rows() {
run_seeded_relation_test(|| {
let error = sole_relation(Relation::<TestUser>::new())
.expect_err("multiple rows should fail sole");
assert!(matches!(error, crate::RecordError::SoleRecordExceeded));
});
}
#[test]
fn relation_pluck_columns_returns_requested_values() {
run_seeded_relation_test(|| {
assert_eq!(
pluck_columns_relation(
Relation::<TestUser>::new().order("id", OrderDirection::Asc),
&["id", "name"],
)
.expect("pluck_columns should succeed"),
vec![
vec![json!(1), json!("Alice")],
vec![json!(2), json!("Bob")],
vec![json!(3), json!("Carol")],
],
);
});
}
#[test]
fn relation_pluck_columns_returns_empty_for_empty_relation() {
run_empty_relation_test(|| {
assert!(
pluck_columns_relation(Relation::<TestUser>::new(), &["id", "name"])
.expect("pluck_columns should succeed")
.is_empty()
);
});
}
#[test]
fn relation_select_columns_alias_populates_requested_columns() {
run_seeded_relation_test(|| {
let users = load_relation(
Relation::<TestUser>::new()
.select_columns(&["name"])
.order("id", OrderDirection::Asc),
)
.expect("relation should load");
assert_eq!(users.len(), 3);
assert_eq!(users[0].name, "Alice");
assert_eq!(users[0].state, RecordState::Persisted);
});
}
// The next two tests pin current behaviour: when `id` is not among the
// selected columns, pluck/pick report `0` rather than failing.
// NOTE(review): presumably this is the model's default for an unloaded field.
// Confirm it is intended rather than incidental.
#[test]
fn relation_pluck_returns_default_for_unselected_column() {
run_seeded_relation_test(|| {
assert_eq!(
pluck_relation(
Relation::<TestUser>::new()
.select_columns(&["name"])
.order("id", OrderDirection::Asc),
"id",
)
.expect("pluck should succeed"),
vec![json!(0), json!(0), json!(0)],
);
});
}
#[test]
fn relation_pick_returns_default_for_unselected_column() {
run_seeded_relation_test(|| {
assert_eq!(
pick_relation(
Relation::<TestUser>::new()
.select_columns(&["name"])
.order("id", OrderDirection::Asc),
"id",
)
.expect("pick should succeed"),
Some(json!(0)),
);
});
}
// Grouped tests add a second Bob (id 4), so the Bob group holds ids 2 and 4.
#[test]
fn relation_group_count_groups_rows_by_key() {
run_seeded_relation_test(|| {
insert_user("Bob", "bobby@example.com");
let grouped = group_count_relation(Relation::<TestUser>::new().group("name"))
.expect("group_count should succeed");
assert_eq!(grouped.get(&json!("Alice")), Some(&json!(1)));
assert_eq!(grouped.get(&json!("Bob")), Some(&json!(2)));
assert_eq!(grouped.get(&json!("Carol")), Some(&json!(1)));
});
}
#[test]
fn relation_group_count_returns_empty_map_for_empty_scope() {
run_seeded_relation_test(|| {
let grouped = group_count_relation(
Relation::<TestUser>::new()
.group("name")
.r#where(HashMap::from([("name".to_owned(), json!("Nobody"))])),
)
.expect("group_count should succeed");
assert!(grouped.is_empty());
});
}
#[test]
fn relation_group_count_applies_having_filters() {
run_seeded_relation_test(|| {
insert_user("Bob", "bobby@example.com");
// `having` takes its condition as SQL text.
let grouped = group_count_relation(
Relation::<TestUser>::new()
.group("name")
.having("COUNT(*) > 1"),
)
.expect("group_count should succeed");
assert_eq!(grouped.len(), 1);
assert_eq!(grouped.get(&json!("Bob")), Some(&json!(2)));
});
}
#[test]
fn relation_group_sum_returns_grouped_totals() {
run_seeded_relation_test(|| {
insert_user("Bob", "bobby@example.com");
let grouped = group_sum_relation(Relation::<TestUser>::new().group("name"), "id")
.expect("group_sum should succeed");
assert_eq!(grouped.get(&json!("Alice")), Some(&json!(1)));
assert_eq!(grouped.get(&json!("Bob")), Some(&json!(6)));
assert_eq!(grouped.get(&json!("Carol")), Some(&json!(3)));
});
}
#[test]
fn relation_group_average_returns_grouped_means() {
run_seeded_relation_test(|| {
insert_user("Bob", "bobby@example.com");
let grouped = group_average_relation(Relation::<TestUser>::new().group("name"), "id")
.expect("group_average should succeed");
assert_eq!(grouped.get(&json!("Alice")), Some(&json!(1.0)));
assert_eq!(grouped.get(&json!("Bob")), Some(&json!(3.0)));
assert_eq!(grouped.get(&json!("Carol")), Some(&json!(3.0)));
});
}
#[test]
fn relation_group_minimum_returns_grouped_extremes() {
run_seeded_relation_test(|| {
insert_user("Bob", "bobby@example.com");
let grouped = group_minimum_relation(Relation::<TestUser>::new().group("name"), "id")
.expect("group_minimum should succeed");
assert_eq!(grouped.get(&json!("Alice")), Some(&json!(1)));
assert_eq!(grouped.get(&json!("Bob")), Some(&json!(2)));
assert_eq!(grouped.get(&json!("Carol")), Some(&json!(3)));
});
}
#[test]
fn relation_group_maximum_returns_grouped_extremes() {
run_seeded_relation_test(|| {
insert_user("Bob", "bobby@example.com");
let grouped = group_maximum_relation(Relation::<TestUser>::new().group("name"), "id")
.expect("group_maximum should succeed");
assert_eq!(grouped.get(&json!("Alice")), Some(&json!(1)));
assert_eq!(grouped.get(&json!("Bob")), Some(&json!(4)));
assert_eq!(grouped.get(&json!("Carol")), Some(&json!(3)));
});
}
#[test]
fn relation_rewhere_replaces_existing_scope() {
run_seeded_relation_test(|| {
assert_eq!(
relation_names(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Alice"))]))
.rewhere(HashMap::from([("name".to_owned(), json!("Bob"))]))
.order("id", OrderDirection::Asc),
),
vec!["Bob"],
);
});
}
#[test]
fn relation_or_where_combines_scopes() {
run_seeded_relation_test(|| {
assert_eq!(
relation_names(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Alice"))]))
.or_where(HashMap::from([("name".to_owned(), json!("Bob"))]))
.order("id", OrderDirection::Asc),
),
vec!["Alice", "Bob"],
);
});
}
#[test]
fn relation_or_where_then_not_where_applies_negation_to_each_branch() {
run_seeded_relation_test(|| {
// The expected result corresponds to this predicate:
//   (name = Alice AND email != bob@) OR (name = Bob AND email != bob@)
// Only Alice survives.
assert_eq!(
relation_names(
Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Alice"))]))
.or_where(HashMap::from([("name".to_owned(), json!("Bob"))]))
.not_where(HashMap::from([(
"email".to_owned(),
json!("bob@example.com"),
)]))
.order("id", OrderDirection::Asc),
),
vec!["Alice"],
);
});
}
#[test]
fn relation_not_where_excludes_matching_rows() {
run_seeded_relation_test(|| {
assert_eq!(
relation_names(
Relation::<TestUser>::new()
.not_where(HashMap::from([("name".to_owned(), json!("Bob"))]))
.order("id", OrderDirection::Asc),
),
vec!["Alice", "Carol"],
);
});
}
#[test]
fn relation_distinct_count_counts_unique_selected_rows() {
run_seeded_relation_test(|| {
// Only `name` is selected, so the two Bob rows collapse into one
// distinct row.
insert_user("Bob", "bobby@example.com");
assert_eq!(
count_relation(
Relation::<TestUser>::new()
.select_columns(&["name"])
.distinct(),
)
.expect("count should succeed"),
3,
);
});
}
#[test]
fn relation_aggregation_empty_relation_returns_zero_sum_and_empty_pluck() {
run_empty_relation_test(|| {
assert_eq!(
sum_relation(Relation::<TestUser>::new(), "id").expect("sum should succeed"),
0.0,
);
assert!(
pluck_relation(Relation::<TestUser>::new(), "name")
.expect("pluck should succeed")
.is_empty()
);
});
}
// The remaining tests inspect builder state directly and touch no database.
#[test]
fn default_relation_starts_empty_and_unscoped() {
let relation = Relation::<TestUser>::default();
// The default holds exactly one empty condition group rather than none.
assert_eq!(relation.where_groups.len(), 1);
assert!(relation.where_groups[0].is_empty());
assert!(relation.order_by.is_empty());
assert!(relation.group_columns.is_empty());
assert!(relation.having_conditions.is_empty());
assert!(!relation.is_distinct);
assert!(relation.select_columns.is_empty());
assert!(relation.join_associations.is_empty());
assert!(relation.included_associations.is_empty());
assert!(relation.limit_val.is_none());
assert!(relation.offset_val.is_none());
}
#[test]
fn where_accumulates_conditions() {
// Successive `where` calls add to the same group; they do not open a new
// OR branch.
let relation = Relation::<TestUser>::new()
.r#where(HashMap::from([("name".to_owned(), json!("Alice"))]))
.r#where(HashMap::from([(
"email".to_owned(),
json!("alice@example.com"),
)]));
assert_eq!(relation.where_groups.len(), 1);
assert_eq!(relation.where_groups[0].positive.len(), 2);
assert!(
relation.where_groups[0]
.positive
.iter()
.any(|(column, value)| column == "name" && value == &json!("Alice"))
);
assert!(
relation.where_groups[0]
.positive
.iter()
.any(|(column, value)| column == "email" && value == &json!("alice@example.com"))
);
}
#[test]
fn not_accumulates_negated_conditions() {
// `not` records its conditions in the current group's `negative` list.
let relation = Relation::<TestUser>::new()
.not(HashMap::from([("name".to_owned(), json!("Alice"))]))
.not(HashMap::from([(
"email".to_owned(),
json!("alice@example.com"),
)]));
assert_eq!(relation.where_groups.len(), 1);
assert_eq!(relation.where_groups[0].negative.len(), 2);
assert!(
relation.where_groups[0]
.negative
.iter()
.any(|(column, value)| column == "name" && value == &json!("Alice"))
);
assert!(
relation.where_groups[0]
.negative
.iter()
.any(|(column, value)| column == "email" && value == &json!("alice@example.com"))
);
}
}