Skip to main content

trailbase_qs/
query.rs

1use base64::prelude::*;
2use itertools::Itertools;
3use serde::Deserialize;
4
5use crate::filter::ValueOrComposite;
6use crate::util::deserialize_bool;
7
8pub type Error = serde_qs::Error;
9
10#[derive(Clone, Debug, PartialEq)]
11pub enum CursorType {
12  Blob,
13  Integer,
14}
15
16/// TrailBase supports cursors in a few formats:
17///  * Integers
18///  * Text-encoded UUIDs ([u8; 16])
19///  * Url-safe base64 encoded blobs including UUIDs.
20///
21/// In practice, we should just support integers and generically blobs. In the future way may want
22/// to use encrypted cursors, which would also just be arbitrary url-safe base64 encoded bytes.
23#[derive(Clone, Debug, PartialEq)]
24pub enum Cursor {
25  Blob(Vec<u8>),
26  Integer(i64),
27}
28
29impl Cursor {
30  pub fn parse(s: &str, cursor_type: CursorType) -> Result<Self, Error> {
31    return match cursor_type {
32      CursorType::Integer => {
33        let i = s.parse::<i64>().map_err(Error::ParseInt)?;
34        Ok(Self::Integer(i))
35      }
36      CursorType::Blob => {
37        if let Ok(uuid) = uuid::Uuid::parse_str(s) {
38          return Ok(Cursor::Blob(uuid.into()));
39        }
40
41        if let Ok(base64) = BASE64_URL_SAFE.decode(s) {
42          return Ok(Cursor::Blob(base64));
43        }
44
45        Err(Error::Custom(format!("Failed to parse: {s}")))
46      }
47    };
48  }
49}
50
51#[derive(Clone, Debug, PartialEq)]
52pub enum OrderPrecedent {
53  Ascending,
54  Descending,
55}
56
57#[derive(Clone, Debug, PartialEq)]
58pub struct Order {
59  pub columns: Vec<(String, OrderPrecedent)>,
60}
61
62impl<'de> serde::de::Deserialize<'de> for Order {
63  fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
64  where
65    D: serde::de::Deserializer<'de>,
66  {
67    use serde::de::Error;
68    use serde_value::Value;
69
70    let value = Value::deserialize(deserializer)?;
71    let Value::String(str) = value else {
72      return Err(Error::invalid_type(
73        crate::util::unexpected(&value),
74        &"comma separated column names to order by",
75      ));
76    };
77
78    let columns = str
79      .split(",")
80      .map(|v| {
81        let col_order = match v.trim() {
82          x if x.starts_with("-") => (v[1..].to_string(), OrderPrecedent::Descending),
83          x if x.starts_with("+") => (v[1..].to_string(), OrderPrecedent::Ascending),
84          x => (x.to_string(), OrderPrecedent::Ascending),
85        };
86
87        if !crate::util::sanitize_column_name(&col_order.0) {
88          return Err(Error::custom(format!(
89            "invalid column name for order: {}",
90            col_order.0
91          )));
92        }
93
94        return Ok(col_order);
95      })
96      .collect::<Result<Vec<_>, _>>()?;
97
98    if columns.len() > 5 {
99      return Err(Error::invalid_length(
100        5,
101        &"more more than 5 order dimension",
102      ));
103    }
104
105    return Ok(Order { columns });
106  }
107}
108
109#[derive(Clone, Debug, PartialEq)]
110pub struct Expand {
111  pub columns: Vec<String>,
112}
113
114impl<'de> serde::de::Deserialize<'de> for Expand {
115  fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
116  where
117    D: serde::de::Deserializer<'de>,
118  {
119    use serde::de::Error;
120    use serde_value::Value;
121
122    let value = Value::deserialize(deserializer)?;
123    let Value::String(str) = value else {
124      return Err(Error::invalid_type(
125        crate::util::unexpected(&value),
126        &"comma separated foreign-key column names to expand",
127      ));
128    };
129
130    let columns = str
131      .split(",")
132      .map(|column_name| {
133        if !crate::util::sanitize_column_name(column_name) {
134          return Err(Error::custom(format!(
135            "invalid column name for expand: {column_name}",
136          )));
137        }
138
139        return Ok(column_name.to_string());
140      })
141      .collect::<Result<Vec<_>, _>>()?;
142
143    if columns.len() > 5 {
144      return Err(Error::invalid_length(
145        5,
146        &"more more than 5 expand dimension",
147      ));
148    }
149
150    return Ok(Expand { columns });
151  }
152}
153
154#[derive(Clone, Default, Debug, PartialEq, Deserialize)]
155pub struct Query {
156  /// Pagination parameters:
157  ///
158  /// Max number of elements returned per page.
159  pub limit: Option<usize>,
160  /// Cursor to page.
161  pub cursor: Option<String>,
162  /// Offset to page. Cursor is more efficient when available
163  pub offset: Option<usize>,
164
165  /// Return total number of rows in the table.
166  #[serde(default, deserialize_with = "deserialize_bool")]
167  pub count: Option<bool>,
168
169  /// Which foreign key columns to expand (only when allowed by configuration).
170  pub expand: Option<Expand>,
171
172  /// Ordering. It's a vector for &order=-col0,+col1,col2
173  pub order: Option<Order>,
174
175  /// Map from filter params to filter value. It's a vector in cases like:
176  ///   `col0[$gte]=2&col0[$lte]=10`.
177  pub filter: Option<ValueOrComposite>,
178}
179
180impl Query {
181  pub fn parse(query: &str) -> Result<Query, Error> {
182    // NOTE: We rely on non-strict mode to parse `filter[col0]=a&b%filter[col1]=c`.
183    let qs = serde_qs::Config::new(9, false);
184    return qs.deserialize_bytes::<Query>(query.as_bytes());
185  }
186
187  /// Produce a query-string representation of this `Query`.
188  pub fn to_query(&self) -> String {
189    let mut pairs: Vec<String> = vec![];
190
191    if let Some(limit) = self.limit {
192      pairs.push(format!("limit={limit}"));
193    }
194
195    if let Some(ref cursor) = self.cursor {
196      pairs.push(format!("cursor={cursor}"));
197    }
198
199    if let Some(offset) = self.offset {
200      pairs.push(format!("offset={offset}"));
201    }
202
203    if let Some(count) = self.count {
204      pairs.push(format!("count={}", if count { "true" } else { "false" }));
205    }
206
207    if let Some(ref expand) = self.expand {
208      let s = expand.columns.join(",");
209      pairs.push(format!("expand={s}"));
210    }
211
212    if let Some(ref order) = self.order {
213      let s = order
214        .columns
215        .iter()
216        .map(|(c, p)| match p {
217          crate::query::OrderPrecedent::Descending => format!("-{}", c),
218          crate::query::OrderPrecedent::Ascending => c.to_string(),
219        })
220        .join(",");
221
222      pairs.push(format!("order={s}"));
223    }
224
225    if let Some(ref filter) = self.filter {
226      pairs.push(filter.to_query());
227    }
228
229    return pairs.into_iter().join("&");
230  }
231}
232
233#[derive(Clone, Default, Debug, PartialEq, Deserialize)]
234pub struct FilterQuery {
235  /// Map from filter params to filter value. It's a vector in cases like:
236  ///   `col0[$gte]=2&col0[$lte]=10`.
237  pub filter: Option<ValueOrComposite>,
238}
239
240impl FilterQuery {
241  pub fn parse(query: &str) -> Result<FilterQuery, Error> {
242    // NOTE: We rely on non-strict mode to parse `filter[col0]=a&b%filter[col1]=c`.
243    let qs = serde_qs::Config::new(9, false);
244    return qs.deserialize_bytes::<FilterQuery>(query.as_bytes());
245  }
246
247  /// Produce query string for only the filter part.
248  pub fn to_query(&self) -> String {
249    if let Some(ref filter) = self.filter {
250      return filter.to_query();
251    }
252    return "".to_string();
253  }
254}
255
// Unit tests for `Query::parse` (and the underlying `serde_qs` config), `Query::to_query`, and
// `Cursor::parse`.
#[cfg(test)]
mod tests {
  use super::*;

  use rusqlite::types::Value as SqlValue;
  use serde_qs::Config;

  use crate::column_rel_value::{ColumnOpValue, CompareOp};
  use crate::filter::Combiner;
  use crate::value::Value;

  // Covers empty/unknown input, encoded filter values, deeply nested composites and the scalar
  // pagination/count parameters.
  #[test]
  fn test_query_basic_parsing() {
    // Empty input and unknown parameters both yield the default query.
    assert_eq!(Query::parse("").unwrap(), Query::default());
    assert_eq!(Query::parse("unknown=foo").unwrap(), Query::default());

    // NOTE: The filter value contains a '&', which will not parse in serde_qs strict-mode. Test
    // explicitly that we properly allow '&'s.
    assert_eq!(
      Query::parse("filter%5Btext_not_null%5D=rust+client+test+0%3A+%3D%3F%261747466199")
        .unwrap()
        .filter
        .unwrap(),
      ValueOrComposite::Value(ColumnOpValue {
        column: "text_not_null".to_string(),
        op: CompareOp::Equal,
        value: Value::String("rust client test 0: =?&1747466199".to_string()),
      })
    );

    // Expected tree: (latency > 2 OR status >= 400) AND latency > 2.
    let expected = ValueOrComposite::Composite(
      Combiner::And,
      vec![
        ValueOrComposite::Composite(
          Combiner::Or,
          vec![
            ValueOrComposite::Value(ColumnOpValue {
              column: "latency".to_string(),
              op: CompareOp::GreaterThan,
              value: Value::Integer(2),
            }),
            ValueOrComposite::Value(ColumnOpValue {
              column: "status".to_string(),
              op: CompareOp::GreaterThanEqual,
              value: Value::Integer(400),
            }),
          ],
        ),
        ValueOrComposite::Value(ColumnOpValue {
          column: "latency".to_string(),
          op: CompareOp::GreaterThan,
          value: Value::Integer(2),
        }),
      ],
    );

    // Make sure depth in the parse config is set large enough to also parse more deeply composed
    // expressions.
    assert_eq!(
      Query::parse("filter[$and][0][$or][0][latency][$gt]=2&filter[$and][0][$or][1][status][$gte]=400&filter[$and][1][latency][$gt]=2")
        .unwrap()
        .filter
        .unwrap(),
        expected
    );

    assert_eq!(
      Query::parse("limit=5&offset=5&count=true").unwrap(),
      Query {
        limit: Some(5),
        offset: Some(5),
        count: Some(true),
        ..Default::default()
      }
    );
    // Boolean parsing for `count` accepts upper-case spellings.
    assert_eq!(
      Query::parse("count=FALSE").unwrap(),
      Query {
        count: Some(false),
        ..Default::default()
      }
    );
    // `offset` is a `usize`, so negative values are rejected.
    assert!(Query::parse("offset=-1").is_err());
  }

  // Serializing a fully-populated `Query` (without filter) back to a query string.
  #[test]
  fn test_basic_to_query() {
    let q = Query {
      limit: Some(10),
      cursor: Some("-5".to_string()),
      offset: Some(2),
      count: Some(true),
      expand: Some(Expand {
        columns: vec!["a".to_string(), "b".to_string()],
      }),
      order: Some(Order {
        columns: vec![
          ("a".to_string(), OrderPrecedent::Ascending),
          ("b".to_string(), OrderPrecedent::Descending),
        ],
      }),
      filter: None,
    };

    let s = q.to_query();
    // Order of params isn't strictly specified; check presence of important fragments.
    assert!(s.contains("limit=10"));
    assert!(s.contains("cursor=-5"));
    assert!(s.contains("offset=2"));
    assert!(s.contains("count=true"));
    assert!(s.contains("expand=a,b"));
    // Ascending columns are emitted without a prefix.
    assert!(s.contains("order=a,-b"));
  }

  #[test]
  fn test_query_order_parsing() {
    let qs = Config::new(5, false);

    // An empty value deserializes to `None` rather than an empty `Order`.
    assert_eq!(
      Query::parse("order=").unwrap(),
      Query {
        order: None,
        ..Default::default()
      },
    );

    // Invalid column names are rejected, and at most 5 columns are accepted.
    assert!(qs.deserialize_str::<Query>("order=$").is_err());
    assert!(qs.deserialize_str::<Query>("order=a,b,c,d,e").is_ok());
    assert!(qs.deserialize_str::<Query>("order=a,b,c,d,e,f").is_err());

    // NOTE: '+' decodes to a space in query values, so "+c" reaches `Order` as " c" and is
    // ascending after trimming.
    assert_eq!(
      qs.deserialize_str::<Query>("order=a,-b,+c").unwrap(),
      Query {
        order: Some(Order {
          columns: vec![
            ("a".to_string(), OrderPrecedent::Ascending),
            ("b".to_string(), OrderPrecedent::Descending),
            ("c".to_string(), OrderPrecedent::Ascending),
          ]
        }),
        ..Default::default()
      }
    );
  }

  #[test]
  fn test_query_expand_parsing() {
    let qs = Config::new(5, false);

    // An empty value deserializes to `None` rather than an empty `Expand`.
    assert_eq!(
      qs.deserialize_str::<Query>("expand=").unwrap(),
      Query {
        expand: None,
        ..Default::default()
      },
    );

    // Invalid column names are rejected, and at most 5 columns are accepted.
    assert!(qs.deserialize_str::<Query>("expand=$").is_err());
    assert!(qs.deserialize_str::<Query>("expand=a,b,c,d,e").is_ok());
    assert!(qs.deserialize_str::<Query>("expand=a,b,c,d,e,f").is_err());
  }

  // Filter parsing into `ValueOrComposite` trees and their conversion to SQL.
  #[test]
  fn test_query_filter_parsing() {
    let qs = Config::new(5, false);

    // An empty filter value yields no filter.
    assert_eq!(
      qs.deserialize_str::<Query>("filter=").unwrap(),
      Query::default()
    );

    // Multiple top-level conditions are combined with an implicit AND.
    let q0: Query = qs
      .deserialize_str("filter[col0][$gt]=0&filter[col1]=val1")
      .unwrap();
    assert_eq!(
      q0.filter.unwrap(),
      ValueOrComposite::Composite(
        Combiner::And,
        vec![
          ValueOrComposite::Value(ColumnOpValue {
            column: "col0".to_string(),
            op: CompareOp::GreaterThan,
            value: Value::Integer(0),
          }),
          ValueOrComposite::Value(ColumnOpValue {
            column: "col1".to_string(),
            op: CompareOp::Equal,
            value: Value::String("val1".to_string()),
          }),
        ]
      )
    );

    // Implicit and with nested or and out of order.
    // The `$or` entries end up ordered by their index ([0] before [1]), not by appearance.
    let q1: Query = qs
      .deserialize_str("filter[$or][1][col0][$ne]=val0&filter[col1]=1&filter[$or][0][col2]=val2")
      .unwrap();
    assert_eq!(
      q1.filter.as_ref().unwrap(),
      &ValueOrComposite::Composite(
        Combiner::And,
        vec![
          ValueOrComposite::Composite(
            Combiner::Or,
            vec![
              ValueOrComposite::Value(ColumnOpValue {
                column: "col2".to_string(),
                op: CompareOp::Equal,
                value: Value::String("val2".to_string()),
              }),
              ValueOrComposite::Value(ColumnOpValue {
                column: "col0".to_string(),
                op: CompareOp::NotEqual,
                value: Value::String("val0".to_string()),
              }),
            ]
          ),
          ValueOrComposite::Value(ColumnOpValue {
            column: "col1".to_string(),
            op: CompareOp::Equal,
            value: Value::Integer(1),
          }),
        ]
      )
    );

    // Maps filter values to SQLite values; the column-name argument is ignored here.
    fn convert(_: &str, value: Value) -> Result<SqlValue, String> {
      return Ok(match value {
        Value::String(s) => SqlValue::Text(s),
        Value::Integer(i) => SqlValue::Integer(i),
        Value::Double(d) => SqlValue::Real(d),
      });
    }

    // Values become named `:__pN` parameters; column names are double-quoted.
    let (sql, params) = q1.filter.clone().unwrap().into_sql(None, &convert).unwrap();
    assert_eq!(
      sql,
      r#"(("col2" = :__p0 OR "col0" <> :__p1) AND "col1" = :__p2)"#
    );
    assert_eq!(
      params,
      vec![
        (":__p0".to_string(), SqlValue::Text("val2".to_string())),
        (":__p1".to_string(), SqlValue::Text("val0".to_string())),
        (":__p2".to_string(), SqlValue::Integer(1)),
      ]
    );
    // With a table alias, every column is qualified by it.
    let (sql, _) = q1.filter.unwrap().into_sql(Some("p"), &convert).unwrap();
    assert_eq!(
      sql,
      r#"((p."col2" = :__p0 OR p."col0" <> :__p1) AND p."col1" = :__p2)"#
    );

    // Test both encodings: '+' and %20 for ' '.
    let q2: Query = qs
      .deserialize_str("filter[col]=with+white%20spaces")
      .unwrap();
    assert_eq!(
      q2.filter.unwrap(),
      ValueOrComposite::Value(ColumnOpValue {
        column: "col".to_string(),
        op: CompareOp::Equal,
        value: Value::String("with white spaces".to_string()),
      }),
    );
  }

  #[test]
  fn test_date_range_filter() {
    // Test that multiple operators on the same column (e.g., date range filters) work correctly
    let result =
      Query::parse("filter[datetime][$gte]=2025-09-25&filter[datetime][$lte]=2025-09-27");

    let query = result.expect("Should parse date range filter");
    let filter = query.filter.expect("Should have filter");

    // Verify it creates an AND composite with two conditions
    match filter {
      ValueOrComposite::Composite(Combiner::And, values) => {
        assert_eq!(values.len(), 2, "Should have two date conditions");

        // Check the conditions are correct
        // NOTE(review): these `if let`s pass silently if an entry is not a plain `Value`.
        if let ValueOrComposite::Value(first) = &values[0] {
          assert_eq!(first.column, "datetime");
          assert_eq!(first.op, CompareOp::GreaterThanEqual);
        }

        if let ValueOrComposite::Value(second) = &values[1] {
          assert_eq!(second.column, "datetime");
          assert_eq!(second.op, CompareOp::LessThanEqual);
        }
      }
      _ => panic!("Expected AND composite filter for date range"),
    }
  }

  // Cursors are kept as raw strings by `Query`; `Cursor::parse` decodes them on demand.
  #[test]
  fn test_query_cursor_parsing() {
    let qs = Config::new(5, false);

    // An empty cursor yields no cursor.
    assert_eq!(
      qs.deserialize_str::<Query>("cursor=").unwrap(),
      Query::default()
    );

    assert_eq!(
      qs.deserialize_str::<Query>("cursor=-5").unwrap(),
      Query {
        cursor: Some("-5".to_string()),
        ..Default::default()
      }
    );

    // A text-encoded UUID is stored verbatim and decodes to the UUID's bytes as a blob cursor.
    let uuid = uuid::Uuid::now_v7();
    let r = qs
      .deserialize_str::<Query>(&format!("cursor={}", uuid.to_string()))
      .unwrap();
    assert_eq!(
      r,
      Query {
        cursor: Some(uuid.to_string()),
        ..Default::default()
      }
    );
    assert_eq!(
      Cursor::parse(&r.cursor.unwrap(), CursorType::Blob).unwrap(),
      Cursor::Blob(uuid.into())
    );

    // A URL-safe base64 cursor string also passes through parsing unchanged.
    let blob = BASE64_URL_SAFE.encode(uuid.as_bytes());
    assert_eq!(
      qs.deserialize_str::<Query>(&format!("cursor={blob}"))
        .unwrap(),
      Query {
        cursor: Some(blob),
        ..Default::default()
      }
    );
  }
}