1use std::collections::{HashMap, HashSet};
2
3use crate::error::{Error, Result};
4
5#[derive(Default)]
11pub struct FilterSchema {
12 fields: HashMap<String, FieldType>,
13 sort_fields: HashSet<String>,
14}
15
/// Value type of a filterable column; selects how raw query-string values are
/// converted before being bound as SQL parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FieldType {
    /// Free text, bound unchanged.
    Text,
    /// Parsed as a 64-bit signed integer.
    Int,
    /// Parsed as a 64-bit float.
    Float,
    /// Bound as text without validation (presumably an ISO-8601 string — confirm with callers).
    Date,
    /// Parsed from `true`/`1`/`yes` or `false`/`0`/`no` and bound as `1`/`0`.
    Bool,
}
28
29impl FilterSchema {
30 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub fn field(mut self, name: &str, typ: FieldType) -> Self {
37 self.fields.insert(name.to_string(), typ);
38 self
39 }
40
41 pub fn sort_fields(mut self, fields: &[&str]) -> Self {
43 self.sort_fields = fields.iter().map(|s| s.to_string()).collect();
44 self
45 }
46
47 fn field_type(&self, name: &str) -> Option<FieldType> {
48 self.fields.get(name).copied()
49 }
50
51 fn is_sort_field(&self, name: &str) -> bool {
52 self.sort_fields.contains(name)
53 }
54}
55
/// Comparison requested by a single filter key of the query string.
#[derive(Debug, Clone)]
enum Operator {
    // Derived from a plain `column` key, depending on how many values it carries.
    /// Plain `column` key with at most one value.
    Eq,
    /// Plain `column` key repeated with several values: set membership.
    In,

    // Selected by an explicit `column.<suffix>` key.
    /// `column.ne`
    Ne,
    /// `column.gt`
    Gt,
    /// `column.gte`
    Gte,
    /// `column.lt`
    Lt,
    /// `column.lte`
    Lte,
    /// `column.like`
    Like,
    /// `column.null`; `true` selects `IS NULL`, `false` selects `IS NOT NULL`.
    IsNull(bool),
}
69
/// One filter parsed from the query string, before it is checked against a schema.
#[derive(Debug, Clone)]
struct FilterCondition {
    /// Column name: the query key with any `.op` suffix removed.
    column: String,
    /// Comparison derived from the key's suffix, or from the value count for plain keys.
    operator: Operator,
    /// String values supplied for the key, in the order given; not yet converted to the column type.
    values: Vec<String>,
}
77
78pub struct Filter {
98 conditions: Vec<FilterCondition>,
99 sort: Vec<String>,
100}
101
/// SQL fragments produced by [`Filter::validate`].
///
/// `clauses` and `params` are positionally linked: the `?` placeholders found in
/// `clauses`, read in order, correspond one-to-one with the entries of `params`.
#[non_exhaustive]
pub struct ValidatedFilter {
    /// Individual WHERE conditions such as `"price" >= ?` (presumably combined with `AND` by the caller — confirm).
    pub clauses: Vec<String>,
    /// Values to bind, in placeholder order.
    pub params: Vec<libsql::Value>,
    /// ORDER BY body without the keyword, e.g. `"name" ASC, "date" DESC`; `None` when no valid sort term was given.
    pub sort_clause: Option<String>,
}
115
116impl ValidatedFilter {
117 pub fn is_empty(&self) -> bool {
119 self.clauses.is_empty()
120 }
121}
122
123impl Filter {
124 pub fn from_query_params(params: &HashMap<String, Vec<String>>) -> Self {
129 let mut conditions: HashMap<String, FilterCondition> = HashMap::new();
130 let mut sort = Vec::new();
131
132 for (key, values) in params {
133 if key == "sort" {
134 sort = values.clone();
135 continue;
136 }
137
138 if key == "page" || key == "per_page" || key == "after" {
140 continue;
141 }
142
143 let (column, op) = if let Some(dot_pos) = key.rfind('.') {
145 let col = &key[..dot_pos];
146 let op_str = &key[dot_pos + 1..];
147 let op = match op_str {
148 "ne" => Operator::Ne,
149 "gt" => Operator::Gt,
150 "gte" => Operator::Gte,
151 "lt" => Operator::Lt,
152 "lte" => Operator::Lte,
153 "like" => Operator::Like,
154 "null" => {
155 let is_null = values.first().map(|v| v == "true").unwrap_or(true);
156 Operator::IsNull(is_null)
157 }
158 _ => continue, };
160 (col.to_string(), op)
161 } else {
162 if values.len() > 1 {
164 (key.clone(), Operator::In)
165 } else {
166 (key.clone(), Operator::Eq)
167 }
168 };
169
170 conditions.insert(
171 key.to_string(),
172 FilterCondition {
173 column,
174 operator: op,
175 values: values.clone(),
176 },
177 );
178 }
179
180 Self {
181 conditions: conditions.into_values().collect(),
182 sort,
183 }
184 }
185
186 pub fn validate(self, schema: &FilterSchema) -> Result<ValidatedFilter> {
196 let mut clauses = Vec::new();
197 let mut params: Vec<libsql::Value> = Vec::new();
198
199 let mut conditions = self.conditions.clone();
200 conditions.sort_by(|a, b| a.column.cmp(&b.column));
201
202 for cond in &conditions {
203 let Some(field_type) = schema.field_type(&cond.column) else {
204 continue; };
206
207 match &cond.operator {
208 Operator::IsNull(is_null) => {
209 if *is_null {
210 clauses.push(format!("\"{}\" IS NULL", cond.column));
211 } else {
212 clauses.push(format!("\"{}\" IS NOT NULL", cond.column));
213 }
214 }
215 Operator::In => {
216 let placeholders: Vec<String> =
217 cond.values.iter().map(|_| "?".to_string()).collect();
218 clauses.push(format!(
219 "\"{}\" IN ({})",
220 cond.column,
221 placeholders.join(", ")
222 ));
223 for val in &cond.values {
224 params.push(convert_value(val, field_type)?);
225 }
226 }
227 op => {
228 let sql_op = match op {
229 Operator::Eq => "=",
230 Operator::Ne => "!=",
231 Operator::Gt => ">",
232 Operator::Gte => ">=",
233 Operator::Lt => "<",
234 Operator::Lte => "<=",
235 Operator::Like => "LIKE",
236 _ => unreachable!(),
237 };
238 clauses.push(format!("\"{}\" {} ?", cond.column, sql_op));
239 let val = cond.values.first().ok_or_else(|| {
240 Error::bad_request(format!("missing value for filter '{}'", cond.column))
241 })?;
242 params.push(convert_value(val, field_type)?);
243 }
244 }
245 }
246
247 let sort_clause = {
249 let mut seen = HashSet::new();
250 let mut parts = Vec::new();
251 for s in &self.sort {
252 let (field, desc) = if let Some(stripped) = s.strip_prefix('-') {
253 (stripped, true)
254 } else {
255 (s.as_str(), false)
256 };
257 if schema.is_sort_field(field) && seen.insert(field) {
258 let direction = if desc { "DESC" } else { "ASC" };
259 parts.push(format!("\"{field}\" {direction}"));
260 }
261 }
262 if parts.is_empty() {
263 None
264 } else {
265 Some(parts.join(", "))
266 }
267 };
268
269 Ok(ValidatedFilter {
270 clauses,
271 params,
272 sort_clause,
273 })
274 }
275}
276
277fn convert_value(val: &str, field_type: FieldType) -> Result<libsql::Value> {
278 match field_type {
279 FieldType::Text | FieldType::Date => Ok(libsql::Value::from(val.to_string())),
280 FieldType::Int => {
281 let n: i64 = val
282 .parse()
283 .map_err(|_| Error::bad_request(format!("invalid integer value: '{val}'")))?;
284 Ok(libsql::Value::from(n))
285 }
286 FieldType::Float => {
287 let n: f64 = val
288 .parse()
289 .map_err(|_| Error::bad_request(format!("invalid float value: '{val}'")))?;
290 Ok(libsql::Value::from(n))
291 }
292 FieldType::Bool => match val {
293 "true" | "1" | "yes" => Ok(libsql::Value::from(1_i32)),
294 "false" | "0" | "no" => Ok(libsql::Value::from(0_i32)),
295 _ => Err(Error::bad_request(format!(
296 "invalid boolean value: '{val}' (expected true/false, 1/0, yes/no)"
297 ))),
298 },
299 }
300}
301
302impl<S: Send + Sync> axum::extract::FromRequestParts<S> for Filter {
304 type Rejection = crate::error::Error;
305
306 async fn from_request_parts(
307 parts: &mut http::request::Parts,
308 _state: &S,
309 ) -> std::result::Result<Self, Self::Rejection> {
310 let uri = &parts.uri;
311 let query = uri.query().unwrap_or("");
312
313 let mut params: HashMap<String, Vec<String>> = HashMap::new();
315 for pair in query.split('&') {
316 if pair.is_empty() {
317 continue;
318 }
319 let (key, value) = match pair.split_once('=') {
320 Some((k, v)) => (k, v),
321 None => (pair, ""),
322 };
323 let key = urlencoding::decode(key)
324 .unwrap_or_else(|_| key.into())
325 .to_string();
326 let value = urlencoding::decode(value)
327 .unwrap_or_else(|_| value.into())
328 .to_string();
329 params.entry(key).or_default().push(value);
330 }
331
332 Ok(Filter::from_query_params(¶ms))
333 }
334}