use crate::validate::assert_valid_sql_identifier;
/// Comparison operator a [`Filter`] applies between its `field` and `value`.
///
/// The SQL each operator renders to is produced outside this section of the
/// file; the per-variant notes below describe the expected meaning and should
/// be confirmed against that renderer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Operator {
/// Equal to.
Eq,
/// Not equal to.
Ne,
/// Greater than.
Gt,
/// Greater than or equal to.
Gte,
/// Less than.
Lt,
/// Less than or equal to.
Lte,
/// Membership test; presumably expects a `Value::Array` operand — TODO confirm.
In,
/// Negated membership test; presumably expects a `Value::Array` operand — TODO confirm.
NotIn,
/// Regular-expression match; the regex dialect depends on the database — confirm.
Regex,
/// Pattern match, presumably SQL `LIKE`.
Like,
/// Case-insensitive pattern match (presumably `ILIKE` or an emulation) — confirm.
ILike,
/// Field begins with the value (presumably a prefix match).
StartsWith,
/// Field ends with the value (presumably a suffix match).
EndsWith,
/// Field contains the value — TODO confirm semantics for non-string values.
Contains,
/// Range test; presumably expects a two-element `Value::Array` of bounds — TODO confirm.
Between,
}
/// Logical connective that joins the children of a [`CompoundFilter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum LogicalOp {
/// Every child expression must hold.
And,
/// At least one child expression must hold.
Or,
/// Negation; `CompoundFilter::not` builds it with exactly one child.
Not,
}
/// A filter tree: either a single [`Filter`] leaf or a [`CompoundFilter`]
/// that combines nested expressions with a [`LogicalOp`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum FilterExpr {
/// A single `field op value` condition.
Simple(Filter),
/// An `And` / `Or` / `Not` node over nested expressions.
Compound(CompoundFilter),
}
/// Nested filter expressions joined by a single logical operator.
///
/// When built with `CompoundFilter::not`, `filters` holds exactly one child.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct CompoundFilter {
/// Operator that joins the children.
pub op: LogicalOp,
/// Child expressions; each may itself be compound.
pub filters: Vec<FilterExpr>,
}
impl CompoundFilter {
#[must_use]
pub const fn and(filters: Vec<FilterExpr>) -> Self {
Self {
op: LogicalOp::And,
filters,
}
}
#[must_use]
pub const fn or(filters: Vec<FilterExpr>) -> Self {
Self {
op: LogicalOp::Or,
filters,
}
}
#[must_use]
pub fn not(filter: FilterExpr) -> Self {
Self {
op: LogicalOp::Not,
filters: vec![filter],
}
}
}
impl FilterExpr {
    /// Returns clones of every leaf [`Filter`] in this tree.
    ///
    /// Leaves are gathered depth-first, visiting each compound node's children
    /// left to right; the logical operators themselves are discarded.
    #[must_use]
    pub fn collect_filters(&self) -> Vec<Filter> {
        let mut leaves = Vec::new();
        self.collect_filters_into(&mut leaves);
        leaves
    }

    /// Recursive worker for [`FilterExpr::collect_filters`]; appends to `out`.
    fn collect_filters_into(&self, out: &mut Vec<Filter>) {
        match self {
            Self::Compound(compound) => compound
                .filters
                .iter()
                .for_each(|child| child.collect_filters_into(out)),
            Self::Simple(filter) => out.push(filter.clone()),
        }
    }

    /// Iterates over the leaf filters; same as iterating `&self`.
    #[must_use]
    pub fn iter(&self) -> FilterExprIter {
        IntoIterator::into_iter(self)
    }
}
/// Owning iterator over the leaf [`Filter`]s of a [`FilterExpr`].
///
/// Produced by `FilterExpr::iter` or by iterating `&FilterExpr`. The leaves are
/// collected (cloned) up front into a `Vec`, so this iterator does not borrow
/// the expression it came from.
#[derive(Debug)]
pub struct FilterExprIter {
    filters: std::vec::IntoIter<Filter>,
}

impl Iterator for FilterExprIter {
    type Item = Filter;

    fn next(&mut self) -> Option<Self::Item> {
        self.filters.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: delegated to the backing `vec::IntoIter`.
        self.filters.size_hint()
    }
}

// The backing `vec::IntoIter` reports an exact length and is fused, so both
// guarantees carry over. This lets callers use `.len()` and rely on `None`
// staying `None`.
impl ExactSizeIterator for FilterExprIter {}
impl std::iter::FusedIterator for FilterExprIter {}
impl IntoIterator for &FilterExpr {
    type Item = Filter;
    type IntoIter = FilterExprIter;

    /// Flattens the expression into its leaf filters (cloned, depth-first)
    /// and hands back an owning iterator over them.
    fn into_iter(self) -> FilterExprIter {
        let leaves: Vec<Filter> = self.collect_filters();
        FilterExprIter {
            filters: leaves.into_iter(),
        }
    }
}
/// SQL aggregate function rendered by [`Aggregate::to_sql`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AggregateFunc {
    /// `COUNT(field)`, or `COUNT(*)` when the aggregate has no field.
    Count,
    /// `COUNT(DISTINCT field)`.
    CountDistinct,
    /// `SUM(field)`.
    Sum,
    /// `AVG(field)`.
    Avg,
    /// `MIN(field)`.
    Min,
    /// `MAX(field)`.
    Max,
}

/// An aggregate expression for a SELECT list, e.g. `SUM(amount) AS total`.
///
/// The field-taking constructors pass `field` through
/// `assert_valid_sql_identifier`, and [`Aggregate::as_alias`] does the same for
/// the alias, because both are later interpolated unquoted by
/// [`Aggregate::to_sql`]. That helper lives in `crate::validate`; by its name it
/// is assumed to panic on an invalid identifier (confirm there). The fields are
/// public, so values assigned to them directly bypass this validation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct Aggregate {
    /// Aggregate function to render.
    pub func: AggregateFunc,
    /// Column the function applies to; `None` renders as `COUNT(*)`.
    pub field: Option<String>,
    /// Optional output name, rendered as ` AS alias`.
    pub alias: Option<String>,
}

impl Aggregate {
    /// `COUNT(*) AS count`.
    #[must_use]
    pub fn count() -> Self {
        Self {
            func: AggregateFunc::Count,
            field: None,
            alias: Some("count".to_string()),
        }
    }

    /// `COUNT(field)`.
    ///
    /// # Panics
    ///
    /// If `field` fails `assert_valid_sql_identifier`.
    #[must_use]
    pub fn count_field(field: impl Into<String>) -> Self {
        Self::for_field(AggregateFunc::Count, field)
    }

    /// `COUNT(DISTINCT field)`.
    ///
    /// # Panics
    ///
    /// If `field` fails `assert_valid_sql_identifier`.
    #[must_use]
    pub fn count_distinct(field: impl Into<String>) -> Self {
        Self::for_field(AggregateFunc::CountDistinct, field)
    }

    /// `SUM(field)`.
    ///
    /// # Panics
    ///
    /// If `field` fails `assert_valid_sql_identifier`.
    #[must_use]
    pub fn sum(field: impl Into<String>) -> Self {
        Self::for_field(AggregateFunc::Sum, field)
    }

    /// `AVG(field)`.
    ///
    /// # Panics
    ///
    /// If `field` fails `assert_valid_sql_identifier`.
    #[must_use]
    pub fn avg(field: impl Into<String>) -> Self {
        Self::for_field(AggregateFunc::Avg, field)
    }

    /// `MIN(field)`.
    ///
    /// # Panics
    ///
    /// If `field` fails `assert_valid_sql_identifier`.
    #[must_use]
    pub fn min(field: impl Into<String>) -> Self {
        Self::for_field(AggregateFunc::Min, field)
    }

    /// `MAX(field)`.
    ///
    /// # Panics
    ///
    /// If `field` fails `assert_valid_sql_identifier`.
    #[must_use]
    pub fn max(field: impl Into<String>) -> Self {
        Self::for_field(AggregateFunc::Max, field)
    }

    /// Sets (or replaces) the output alias.
    ///
    /// # Panics
    ///
    /// If `alias` fails `assert_valid_sql_identifier`.
    #[must_use]
    pub fn as_alias(mut self, alias: impl Into<String>) -> Self {
        let alias = alias.into();
        assert_valid_sql_identifier(&alias, "aggregate alias");
        self.alias = Some(alias);
        self
    }

    /// Renders the aggregate as a SQL select-list fragment, e.g. `SUM(x) AS total`.
    #[must_use]
    pub fn to_sql(&self) -> String {
        let expr = match (self.func, self.field.as_deref()) {
            (AggregateFunc::Count, None) => "COUNT(*)".to_string(),
            (AggregateFunc::Count, Some(f)) => format!("COUNT({f})"),
            (AggregateFunc::CountDistinct, Some(f)) => format!("COUNT(DISTINCT {f})"),
            (AggregateFunc::Sum, Some(f)) => format!("SUM({f})"),
            (AggregateFunc::Avg, Some(f)) => format!("AVG({f})"),
            (AggregateFunc::Min, Some(f)) => format!("MIN({f})"),
            (AggregateFunc::Max, Some(f)) => format!("MAX({f})"),
            // Only reachable when the public `field` was cleared after
            // construction. The historical `COUNT(*)` fallback is kept, but the
            // variants are listed explicitly (no `_` arm) so adding an
            // `AggregateFunc` variant is a compile error here, not silent output.
            // NOTE(review): turning e.g. SUM into COUNT(*) is surprising;
            // consider making this state unrepresentable.
            (
                AggregateFunc::CountDistinct
                | AggregateFunc::Sum
                | AggregateFunc::Avg
                | AggregateFunc::Min
                | AggregateFunc::Max,
                None,
            ) => "COUNT(*)".to_string(),
        };
        match &self.alias {
            Some(alias) => format!("{expr} AS {alias}"),
            None => expr,
        }
    }

    /// Shared body of the field-taking constructors: validates `field` and
    /// builds an un-aliased aggregate of kind `func`.
    fn for_field(func: AggregateFunc, field: impl Into<String>) -> Self {
        let field = field.into();
        assert_valid_sql_identifier(&field, "aggregate field");
        Self {
            func,
            field: Some(field),
            alias: None,
        }
    }
}
/// A literal value, used both as a filter operand (`Filter::value`) and as a
/// query parameter (`QueryResult::params`).
///
/// Derives only `PartialEq` (not `Eq`) because `Float` holds an `f64`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum Value {
/// Absence of a value (presumably rendered as SQL `NULL`).
Null,
/// Boolean literal.
Bool(bool),
/// 64-bit signed integer literal.
Int(i64),
/// 64-bit floating-point literal.
Float(f64),
/// Text literal.
String(String),
/// List of values, presumably for `In` / `NotIn` / `Between` operands — confirm.
Array(Vec<Self>),
}
/// Direction of a single ORDER BY term.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SortDir {
    /// Ascending; a bare field name in a sort string selects this.
    Asc,
    /// Descending; a leading `-` in a sort string selects this.
    Desc,
}

/// One sort key: a field name plus a direction.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct SortField {
    /// Field name; [`SortField::new`] does not validate it.
    pub field: String,
    /// Sort direction.
    pub dir: SortDir,
}

impl SortField {
    /// Creates a sort key. `field` is stored as-is, without validation.
    #[must_use]
    pub fn new(field: impl Into<String>, dir: SortDir) -> Self {
        Self {
            field: field.into(),
            dir,
        }
    }

    /// Parses a comma-separated sort specification such as `"name,-created_at"`.
    ///
    /// Each term is trimmed, and empty terms are skipped, so `""` and `" , "`
    /// yield an empty list. Exactly one leading `-` is stripped and selects
    /// [`SortDir::Desc`]; otherwise the term is [`SortDir::Asc`]. When `allowed`
    /// is non-empty, every field must appear in it verbatim.
    ///
    /// An empty `allowed` disables the allow-list entirely, and field names are
    /// then passed through without any validation. If `sort` comes from user
    /// input, supply a non-empty `allowed` before the fields reach SQL.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// - a field is not listed in a non-empty `allowed`; or
    /// - a term is a bare `-` with no field name. Previously this produced a
    ///   `SortField` with an empty name when `allowed` was empty.
    pub fn parse_sort_string(sort: &str, allowed: &[&str]) -> Result<Vec<Self>, String> {
        sort.split(',')
            .map(str::trim)
            .filter(|part| !part.is_empty())
            .map(|part| {
                let (field, dir) = part
                    .strip_prefix('-')
                    .map_or((part, SortDir::Asc), |stripped| (stripped, SortDir::Desc));
                if !allowed.is_empty() && !allowed.contains(&field) {
                    return Err(format!(
                        "Sort field '{field}' not allowed. Allowed: {allowed:?}"
                    ));
                }
                // Only a bare "-" reaches here with an empty name, because empty
                // terms were filtered out above.
                if field.is_empty() {
                    return Err("Sort field name missing after '-'".to_string());
                }
                Ok(Self::new(field, dir))
            })
            // Short-circuits on the first `Err`, like an early return in a loop.
            .collect()
    }
}
/// A single `field op value` condition; the leaf node of a [`FilterExpr`].
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct Filter {
/// Name of the field the condition applies to.
pub field: String,
/// Comparison operator.
pub op: Operator,
/// Right-hand operand of the comparison.
pub value: Value,
}
impl Filter {
    /// Assembles a filter from a field name, an operator and a value.
    ///
    /// NOTE(review): unlike the module-level `simple` helper, this constructor
    /// does not pass `field` through `assert_valid_sql_identifier`. Confirm that
    /// callers validate the name before it reaches SQL.
    #[must_use]
    pub fn new(field: impl Into<String>, op: Operator, value: Value) -> Self {
        let field: String = field.into();
        Self { op, value, field }
    }
}
/// A SQL string together with the parameter values that accompany it.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
#[must_use = "QueryResult must be used to execute the query"]
pub struct QueryResult {
/// SQL text.
pub sql: String,
/// Parameter values; presumably bound, in order, to placeholders in `sql` — confirm.
pub params: Vec<Value>,
}
impl QueryResult {
    /// Bundles `sql` with the parameter values that accompany it.
    ///
    /// This has no `#[must_use]` of its own. `QueryResult` is already
    /// `#[must_use]`, so an unused result is reported anyway, and repeating a
    /// bare attribute here triggers `clippy::double_must_use`.
    pub fn new(sql: impl Into<String>, params: Vec<Value>) -> Self {
        Self {
            sql: sql.into(),
            params,
        }
    }
}
/// A derived output column, rendered as `(expression) AS alias`.
///
/// Neither field is validated or escaped here. `expression` is emitted verbatim
/// and `alias` is interpolated unquoted, so both must be trusted,
/// developer-authored SQL — never user input.
///
/// NOTE(review): `Aggregate::as_alias` runs its alias through
/// `assert_valid_sql_identifier`; consider doing the same for `alias` here.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct ComputedField {
    /// Output column name, written after `AS`.
    pub alias: String,
    /// Raw SQL expression, written inside parentheses.
    pub expression: String,
}

impl ComputedField {
    /// Creates a computed field from an output alias and a raw SQL expression.
    #[must_use]
    pub fn new(alias: impl Into<String>, expression: impl Into<String>) -> Self {
        Self {
            alias: alias.into(),
            expression: expression.into(),
        }
    }

    /// Renders the field as `(expression) AS alias`.
    #[must_use]
    pub fn to_sql(&self) -> String {
        format!("({}) AS {}", self.expression, self.alias)
    }
}
/// Direction relative to a pagination cursor.
///
/// NOTE(review): this enum is not used within this section; the variant
/// meanings below are inferred from their names — confirm against the
/// pagination code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CursorDirection {
/// Presumably rows that sort after the cursor.
After,
/// Presumably rows that sort before the cursor.
Before,
}
/// Builds a leaf filter expression `field op value`.
///
/// `field` is checked with `assert_valid_sql_identifier` (from
/// `crate::validate`) because it is later used as an identifier in SQL.
///
/// # Panics
///
/// If `field` fails `assert_valid_sql_identifier`. That helper is assumed to
/// panic on an invalid name — confirm its rules in `crate::validate`.
#[must_use]
pub fn simple(field: impl Into<String>, op: Operator, value: Value) -> FilterExpr {
    let field = field.into();
    assert_valid_sql_identifier(&field, "filter field");
    FilterExpr::Simple(Filter { field, op, value })
}
/// Builds a logical AND over `filters`, wrapped as a [`FilterExpr`].
#[must_use]
pub const fn and(filters: Vec<FilterExpr>) -> FilterExpr {
    let conjunction = CompoundFilter::and(filters);
    FilterExpr::Compound(conjunction)
}
/// Builds a logical OR over `filters`, wrapped as a [`FilterExpr`].
#[must_use]
pub const fn or(filters: Vec<FilterExpr>) -> FilterExpr {
    let disjunction = CompoundFilter::or(filters);
    FilterExpr::Compound(disjunction)
}
/// Negates `filter`, wrapped as a [`FilterExpr`].
#[must_use]
pub fn not(filter: FilterExpr) -> FilterExpr {
    let negation = CompoundFilter::not(filter);
    FilterExpr::Compound(negation)
}