Skip to main content

modo/db/
filter.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::error::{Error, Result};
4
5/// Declares the allowed filter fields and sort fields for an endpoint.
6///
7/// Build one per endpoint using the builder methods [`field`](Self::field)
8/// and [`sort_fields`](Self::sort_fields), then pass it to
9/// [`Filter::validate`] to produce a [`ValidatedFilter`].
10#[derive(Default)]
11pub struct FilterSchema {
12    fields: HashMap<String, FieldType>,
13    sort_fields: HashSet<String>,
14}
15
/// Column type used for validating filter values.
///
/// Determines how string query-parameter values are converted to
/// `libsql::Value` during [`Filter::validate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FieldType {
    /// Passed through unchanged as a text value.
    Text,
    /// Parsed as a signed 64-bit integer (`i64`).
    Int,
    /// Parsed as a 64-bit float (`f64`).
    Float,
    /// Passed through unchanged as text; the date format is not validated.
    Date,
    /// Accepts `true`/`1`/`yes` or `false`/`0`/`no`, bound as the integer `1` or `0`.
    Bool,
}
28
29impl FilterSchema {
30    /// Create an empty schema.
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    /// Add an allowed filter field with its expected type.
36    pub fn field(mut self, name: &str, typ: FieldType) -> Self {
37        self.fields.insert(name.to_string(), typ);
38        self
39    }
40
41    /// Set the allowed sort field names.
42    pub fn sort_fields(mut self, fields: &[&str]) -> Self {
43        self.sort_fields = fields.iter().map(|s| s.to_string()).collect();
44        self
45    }
46
47    fn field_type(&self, name: &str) -> Option<FieldType> {
48        self.fields.get(name).copied()
49    }
50
51    fn is_sort_field(&self, name: &str) -> bool {
52        self.sort_fields.contains(name)
53    }
54}
55
/// Comparison operator derived from a query-string key (and, for plain keys,
/// from how many values were supplied).
#[derive(Clone, Debug)]
enum Operator {
    /// Plain `field` key with a single value: `=`.
    Eq,
    /// `.ne` suffix: `!=`.
    Ne,
    /// `.gt` suffix: `>`.
    Gt,
    /// `.gte` suffix: `>=`.
    Gte,
    /// `.lt` suffix: `<`.
    Lt,
    /// `.lte` suffix: `<=`.
    Lte,
    /// `.like` suffix: SQL `LIKE`.
    Like,
    /// `.null` suffix: `true` renders `IS NULL`, `false` renders `IS NOT NULL`.
    IsNull(bool),
    /// Plain `field` key with more than one value: `IN (...)`.
    In,
}
69
70/// A single parsed filter condition.
71#[derive(Debug, Clone)]
72struct FilterCondition {
73    column: String,
74    operator: Operator,
75    values: Vec<String>,
76}
77
78/// Raw parsed filter from query string.
79///
80/// Implements `FromRequestParts` so it can be used directly as an axum
81/// handler argument. Must be validated against a [`FilterSchema`] via
82/// [`validate`](Self::validate) before use in SQL generation.
83///
84/// ## Supported query-string syntax
85///
86/// | Pattern | Meaning |
87/// |---------|---------|
88/// | `field=value` | Equality (`=`), or `IN` if multiple values |
89/// | `field.ne=value` | Not equal (`!=`) |
90/// | `field.gt=value` | Greater than (`>`) |
91/// | `field.gte=value` | Greater than or equal (`>=`) |
92/// | `field.lt=value` | Less than (`<`) |
93/// | `field.lte=value` | Less than or equal (`<=`) |
94/// | `field.like=value` | SQL `LIKE` |
95/// | `field.null=true` | `IS NULL` / `IS NOT NULL` |
96/// | `sort=field` | Sort ascending; `sort=-field` for descending; repeat for multi-column |
97pub struct Filter {
98    conditions: Vec<FilterCondition>,
99    sort: Vec<String>,
100}
101
102/// Schema-validated filter, safe for SQL generation.
103///
104/// Produced by [`Filter::validate`]. Contains parameterized WHERE clauses
105/// and an optional ORDER BY clause. Used by [`SelectBuilder`](super::SelectBuilder).
106#[non_exhaustive]
107pub struct ValidatedFilter {
108    /// WHERE clause fragments (joined with `AND`).
109    pub clauses: Vec<String>,
110    /// Bind parameters corresponding to `?` placeholders in `clauses`.
111    pub params: Vec<libsql::Value>,
112    /// Optional ORDER BY clause from the `sort` parameter.
113    pub sort_clause: Option<String>,
114}
115
116impl ValidatedFilter {
117    /// Returns `true` if no filter conditions were produced.
118    pub fn is_empty(&self) -> bool {
119        self.clauses.is_empty()
120    }
121}
122
123impl Filter {
124    /// Parse filter conditions from a pre-parsed query string map.
125    ///
126    /// Pagination parameters (`page`, `per_page`, `after`) are silently
127    /// skipped. Unknown operators are ignored.
128    pub fn from_query_params(params: &HashMap<String, Vec<String>>) -> Self {
129        let mut conditions: HashMap<String, FilterCondition> = HashMap::new();
130        let mut sort = Vec::new();
131
132        for (key, values) in params {
133            if key == "sort" {
134                sort = values.clone();
135                continue;
136            }
137
138            // Skip pagination params
139            if key == "page" || key == "per_page" || key == "after" {
140                continue;
141            }
142
143            // Parse operator from key: "field.op" or just "field"
144            let (column, op) = if let Some(dot_pos) = key.rfind('.') {
145                let col = &key[..dot_pos];
146                let op_str = &key[dot_pos + 1..];
147                let op = match op_str {
148                    "ne" => Operator::Ne,
149                    "gt" => Operator::Gt,
150                    "gte" => Operator::Gte,
151                    "lt" => Operator::Lt,
152                    "lte" => Operator::Lte,
153                    "like" => Operator::Like,
154                    "null" => {
155                        let is_null = values.first().map(|v| v == "true").unwrap_or(true);
156                        Operator::IsNull(is_null)
157                    }
158                    _ => continue, // Unknown operator — skip
159                };
160                (col.to_string(), op)
161            } else {
162                // No operator — Eq (single value) or In (multiple values)
163                if values.len() > 1 {
164                    (key.clone(), Operator::In)
165                } else {
166                    (key.clone(), Operator::Eq)
167                }
168            };
169
170            conditions.insert(
171                key.to_string(),
172                FilterCondition {
173                    column,
174                    operator: op,
175                    values: values.clone(),
176                },
177            );
178        }
179
180        Self {
181            conditions: conditions.into_values().collect(),
182            sort,
183        }
184    }
185
186    /// Validate against a schema, producing a [`ValidatedFilter`].
187    ///
188    /// Unknown columns are silently ignored. Sort fields not listed in the
189    /// schema are dropped.
190    ///
191    /// # Errors
192    ///
193    /// Returns a 400 error if a filter value cannot be converted to the
194    /// declared [`FieldType`] (e.g., `"abc"` for an `Int` field).
195    pub fn validate(self, schema: &FilterSchema) -> Result<ValidatedFilter> {
196        let mut clauses = Vec::new();
197        let mut params: Vec<libsql::Value> = Vec::new();
198
199        let mut conditions = self.conditions.clone();
200        conditions.sort_by(|a, b| a.column.cmp(&b.column));
201
202        for cond in &conditions {
203            let Some(field_type) = schema.field_type(&cond.column) else {
204                continue; // Unknown column — silently ignore
205            };
206
207            match &cond.operator {
208                Operator::IsNull(is_null) => {
209                    if *is_null {
210                        clauses.push(format!("\"{}\" IS NULL", cond.column));
211                    } else {
212                        clauses.push(format!("\"{}\" IS NOT NULL", cond.column));
213                    }
214                }
215                Operator::In => {
216                    let placeholders: Vec<String> =
217                        cond.values.iter().map(|_| "?".to_string()).collect();
218                    clauses.push(format!(
219                        "\"{}\" IN ({})",
220                        cond.column,
221                        placeholders.join(", ")
222                    ));
223                    for val in &cond.values {
224                        params.push(convert_value(val, field_type)?);
225                    }
226                }
227                op => {
228                    let sql_op = match op {
229                        Operator::Eq => "=",
230                        Operator::Ne => "!=",
231                        Operator::Gt => ">",
232                        Operator::Gte => ">=",
233                        Operator::Lt => "<",
234                        Operator::Lte => "<=",
235                        Operator::Like => "LIKE",
236                        _ => unreachable!(),
237                    };
238                    clauses.push(format!("\"{}\" {} ?", cond.column, sql_op));
239                    let val = cond.values.first().ok_or_else(|| {
240                        Error::bad_request(format!("missing value for filter '{}'", cond.column))
241                    })?;
242                    params.push(convert_value(val, field_type)?);
243                }
244            }
245        }
246
247        // Validate sort
248        let sort_clause = {
249            let mut seen = HashSet::new();
250            let mut parts = Vec::new();
251            for s in &self.sort {
252                let (field, desc) = if let Some(stripped) = s.strip_prefix('-') {
253                    (stripped, true)
254                } else {
255                    (s.as_str(), false)
256                };
257                if schema.is_sort_field(field) && seen.insert(field) {
258                    let direction = if desc { "DESC" } else { "ASC" };
259                    parts.push(format!("\"{field}\" {direction}"));
260                }
261            }
262            if parts.is_empty() {
263                None
264            } else {
265                Some(parts.join(", "))
266            }
267        };
268
269        Ok(ValidatedFilter {
270            clauses,
271            params,
272            sort_clause,
273        })
274    }
275}
276
277fn convert_value(val: &str, field_type: FieldType) -> Result<libsql::Value> {
278    match field_type {
279        FieldType::Text | FieldType::Date => Ok(libsql::Value::from(val.to_string())),
280        FieldType::Int => {
281            let n: i64 = val
282                .parse()
283                .map_err(|_| Error::bad_request(format!("invalid integer value: '{val}'")))?;
284            Ok(libsql::Value::from(n))
285        }
286        FieldType::Float => {
287            let n: f64 = val
288                .parse()
289                .map_err(|_| Error::bad_request(format!("invalid float value: '{val}'")))?;
290            Ok(libsql::Value::from(n))
291        }
292        FieldType::Bool => match val {
293            "true" | "1" | "yes" => Ok(libsql::Value::from(1_i32)),
294            "false" | "0" | "no" => Ok(libsql::Value::from(0_i32)),
295            _ => Err(Error::bad_request(format!(
296                "invalid boolean value: '{val}' (expected true/false, 1/0, yes/no)"
297            ))),
298        },
299    }
300}
301
302// axum extractor
303impl<S: Send + Sync> axum::extract::FromRequestParts<S> for Filter {
304    type Rejection = crate::error::Error;
305
306    async fn from_request_parts(
307        parts: &mut http::request::Parts,
308        _state: &S,
309    ) -> std::result::Result<Self, Self::Rejection> {
310        let uri = &parts.uri;
311        let query = uri.query().unwrap_or("");
312
313        // Parse query string into HashMap<String, Vec<String>>
314        let mut params: HashMap<String, Vec<String>> = HashMap::new();
315        for pair in query.split('&') {
316            if pair.is_empty() {
317                continue;
318            }
319            let (key, value) = match pair.split_once('=') {
320                Some((k, v)) => (k, v),
321                None => (pair, ""),
322            };
323            let key = urlencoding::decode(key)
324                .unwrap_or_else(|_| key.into())
325                .to_string();
326            let value = urlencoding::decode(value)
327                .unwrap_or_else(|_| value.into())
328                .to_string();
329            params.entry(key).or_default().push(value);
330        }
331
332        Ok(Filter::from_query_params(&params))
333    }
334}