use std::collections::{HashMap, HashSet};
use crate::error::{Error, Result};
/// Whitelist of filterable columns (with their value types) and sortable columns.
///
/// Only column names registered here can end up in SQL produced by [`Filter::validate`].
#[derive(Debug, Clone, Default)]
pub struct FilterSchema {
    /// Filterable column name -> type used to convert incoming query values.
    fields: HashMap<String, FieldType>,
    /// Column names accepted as `sort` keys.
    sort_fields: HashSet<String>,
}
/// Declared type of a filterable column.
///
/// It decides how raw query-string values are converted into SQL parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FieldType {
    /// Bound as a string, unchanged.
    Text,
    /// Parsed as `i64`.
    Int,
    /// Parsed as `f64`.
    Float,
    /// Bound as a string, unchanged (the date format is not validated).
    Date,
    /// Accepts `true`/`false`, `1`/`0`, `yes`/`no`; bound as integer 1 or 0.
    Bool,
}
impl FilterSchema {
pub fn new() -> Self {
Self::default()
}
pub fn field(mut self, name: &str, typ: FieldType) -> Self {
self.fields.insert(name.to_string(), typ);
self
}
pub fn sort_fields(mut self, fields: &[&str]) -> Self {
self.sort_fields = fields.iter().map(|s| s.to_string()).collect();
self
}
fn field_type(&self, name: &str) -> Option<FieldType> {
self.fields.get(name).copied()
}
fn is_sort_field(&self, name: &str) -> bool {
self.sort_fields.contains(name)
}
}
/// Comparison requested by a single filter parameter.
#[derive(Clone, Debug)]
enum Operator {
    /// Plain `col` key with one value: `=`.
    Eq,
    /// Plain `col` key repeated: `IN (...)`.
    In,
    /// `col.ne`: `!=`.
    Ne,
    /// `col.gt`: `>`.
    Gt,
    /// `col.gte`: `>=`.
    Gte,
    /// `col.lt`: `<`.
    Lt,
    /// `col.lte`: `<=`.
    Lte,
    /// `col.like`: `LIKE`.
    Like,
    /// `col.null`: `true` renders `IS NULL`, `false` renders `IS NOT NULL`.
    IsNull(bool),
}
/// One parsed filter parameter, before it is checked against a schema.
#[derive(Clone, Debug)]
struct FilterCondition {
    /// Column name from the query key (the part before the last `.` when an operator suffix is present).
    column: String,
    /// Comparison to apply to the column.
    operator: Operator,
    /// Raw decoded values for the key.
    ///
    /// `In` binds every value; the single-value comparisons bind only the first.
    values: Vec<String>,
}
/// Filter and sort request parsed from a query string, not yet checked against a schema.
///
/// Build one with [`Filter::from_query_params`] (or the axum extractor). Then call
/// [`Filter::validate`] to obtain SQL fragments.
#[derive(Debug, Clone)]
pub struct Filter {
    /// Parsed conditions; columns unknown to any schema are still present at this stage.
    conditions: Vec<FilterCondition>,
    /// Raw `sort` values, each optionally prefixed with `-` for descending order.
    sort: Vec<String>,
}
/// SQL fragments produced by [`Filter::validate`].
///
/// `clauses` and `params` line up positionally. The `?` placeholders in `clauses`, read in
/// order, correspond one-to-one with the entries of `params`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ValidatedFilter {
    /// Individual predicates such as `"age" >= ?`.
    ///
    /// They are not joined here; the caller decides how to combine them.
    pub clauses: Vec<String>,
    /// Values bound to the `?` placeholders in `clauses`, in order.
    pub params: Vec<libsql::Value>,
    /// `ORDER BY` body such as `"name" ASC, "created" DESC` (without the keyword).
    ///
    /// `None` when no valid sort key was requested.
    pub sort_clause: Option<String>,
}
impl ValidatedFilter {
    /// Returns `true` when no filter clauses were produced.
    ///
    /// Only `clauses` is inspected. A filter that has a `sort_clause` but no conditions still
    /// counts as empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.clauses.is_empty()
    }
}
impl Filter {
pub fn from_query_params(params: &HashMap<String, Vec<String>>) -> Self {
let mut conditions: HashMap<String, FilterCondition> = HashMap::new();
let mut sort = Vec::new();
for (key, values) in params {
if key == "sort" {
sort = values.clone();
continue;
}
if key == "page" || key == "per_page" || key == "after" {
continue;
}
let (column, op) = if let Some(dot_pos) = key.rfind('.') {
let col = &key[..dot_pos];
let op_str = &key[dot_pos + 1..];
let op = match op_str {
"ne" => Operator::Ne,
"gt" => Operator::Gt,
"gte" => Operator::Gte,
"lt" => Operator::Lt,
"lte" => Operator::Lte,
"like" => Operator::Like,
"null" => {
let is_null = values.first().map(|v| v == "true").unwrap_or(true);
Operator::IsNull(is_null)
}
_ => continue, };
(col.to_string(), op)
} else {
if values.len() > 1 {
(key.clone(), Operator::In)
} else {
(key.clone(), Operator::Eq)
}
};
conditions.insert(
key.to_string(),
FilterCondition {
column,
operator: op,
values: values.clone(),
},
);
}
Self {
conditions: conditions.into_values().collect(),
sort,
}
}
pub fn validate(self, schema: &FilterSchema) -> Result<ValidatedFilter> {
let mut clauses = Vec::new();
let mut params: Vec<libsql::Value> = Vec::new();
let mut conditions = self.conditions.clone();
conditions.sort_by(|a, b| a.column.cmp(&b.column));
for cond in &conditions {
let Some(field_type) = schema.field_type(&cond.column) else {
continue; };
match &cond.operator {
Operator::IsNull(is_null) => {
if *is_null {
clauses.push(format!("\"{}\" IS NULL", cond.column));
} else {
clauses.push(format!("\"{}\" IS NOT NULL", cond.column));
}
}
Operator::In => {
let placeholders: Vec<String> =
cond.values.iter().map(|_| "?".to_string()).collect();
clauses.push(format!(
"\"{}\" IN ({})",
cond.column,
placeholders.join(", ")
));
for val in &cond.values {
params.push(convert_value(val, field_type)?);
}
}
op => {
let sql_op = match op {
Operator::Eq => "=",
Operator::Ne => "!=",
Operator::Gt => ">",
Operator::Gte => ">=",
Operator::Lt => "<",
Operator::Lte => "<=",
Operator::Like => "LIKE",
_ => unreachable!(),
};
clauses.push(format!("\"{}\" {} ?", cond.column, sql_op));
let val = cond.values.first().ok_or_else(|| {
Error::bad_request(format!("missing value for filter '{}'", cond.column))
})?;
params.push(convert_value(val, field_type)?);
}
}
}
let sort_clause = {
let mut seen = HashSet::new();
let mut parts = Vec::new();
for s in &self.sort {
let (field, desc) = if let Some(stripped) = s.strip_prefix('-') {
(stripped, true)
} else {
(s.as_str(), false)
};
if schema.is_sort_field(field) && seen.insert(field) {
let direction = if desc { "DESC" } else { "ASC" };
parts.push(format!("\"{field}\" {direction}"));
}
}
if parts.is_empty() {
None
} else {
Some(parts.join(", "))
}
};
Ok(ValidatedFilter {
clauses,
params,
sort_clause,
})
}
}
fn convert_value(val: &str, field_type: FieldType) -> Result<libsql::Value> {
match field_type {
FieldType::Text | FieldType::Date => Ok(libsql::Value::from(val.to_string())),
FieldType::Int => {
let n: i64 = val
.parse()
.map_err(|_| Error::bad_request(format!("invalid integer value: '{val}'")))?;
Ok(libsql::Value::from(n))
}
FieldType::Float => {
let n: f64 = val
.parse()
.map_err(|_| Error::bad_request(format!("invalid float value: '{val}'")))?;
Ok(libsql::Value::from(n))
}
FieldType::Bool => match val {
"true" | "1" | "yes" => Ok(libsql::Value::from(1_i32)),
"false" | "0" | "no" => Ok(libsql::Value::from(0_i32)),
_ => Err(Error::bad_request(format!(
"invalid boolean value: '{val}' (expected true/false, 1/0, yes/no)"
))),
},
}
}
impl<S: Send + Sync> axum::extract::FromRequestParts<S> for Filter {
    type Rejection = crate::error::Error;

    /// Parses the request's query string into a [`Filter`].
    ///
    /// Pairs are split on `&` and then on the first `=`. A key without `=` gets an empty
    /// value, and repeated keys accumulate their values in order. This extractor never
    /// rejects a request.
    async fn from_request_parts(
        parts: &mut http::request::Parts,
        _state: &S,
    ) -> std::result::Result<Self, Self::Rejection> {
        /// Decodes one `application/x-www-form-urlencoded` component.
        ///
        /// `+` becomes a space, then percent-escapes are decoded. If the result is not valid
        /// UTF-8, the `+`-replaced raw text is returned instead.
        fn decode_component(raw: &str) -> String {
            // `urlencoding::decode` keeps `+` literally, but form encoding uses `+` for a
            // space (a literal plus is sent as `%2B`).
            let spaced = raw.replace('+', " ");
            let decoded = urlencoding::decode(&spaced).map(|s| s.into_owned());
            decoded.unwrap_or(spaced)
        }

        let query = parts.uri.query().unwrap_or("");
        let mut params: HashMap<String, Vec<String>> = HashMap::new();
        for pair in query.split('&').filter(|pair| !pair.is_empty()) {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            params
                .entry(decode_component(key))
                .or_default()
                .push(decode_component(value));
        }
        Ok(Filter::from_query_params(&params))
    }
}