Skip to main content

modkit_odata/
lib.rs

1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
2pub mod builder;
3pub mod errors;
4pub mod filter;
5pub mod limits;
6pub(crate) mod odata_filters;
7mod odata_parse;
8pub mod page;
9pub mod pagination;
10pub mod problem_mapping;
11pub mod schema;
12
13pub use builder::QueryBuilder;
14pub use limits::ODataLimits;
15pub use page::{Page, PageInfo};
16pub use pagination::{normalize_filter_for_hash, short_filter_hash};
17pub use schema::{FieldRef, Schema};
18
19pub mod ast {
20    use bigdecimal::BigDecimal;
21    use chrono::{DateTime, NaiveDate, NaiveTime, Utc};
22    use uuid::Uuid;
23
24    #[derive(Clone, Debug)]
25    pub enum Expr {
26        And(Box<Expr>, Box<Expr>),
27        Or(Box<Expr>, Box<Expr>),
28        Not(Box<Expr>),
29        Compare(Box<Expr>, CompareOperator, Box<Expr>),
30        In(Box<Expr>, Vec<Expr>),
31        Function(String, Vec<Expr>),
32        Identifier(String),
33        Value(Value),
34    }
35
36    impl Expr {
37        /// Combine two expressions with AND: `expr1 and expr2`
38        ///
39        /// # Example
40        ///
41        /// ```rust,ignore
42        /// let filter = ID.eq(user_id).and(NAME.contains("john"));
43        /// ```
44        #[must_use]
45        pub fn and(self, other: Expr) -> Expr {
46            Expr::And(Box::new(self), Box::new(other))
47        }
48
49        /// Combine two expressions with OR: `expr1 or expr2`
50        #[must_use]
51        pub fn or(self, other: Expr) -> Expr {
52            Expr::Or(Box::new(self), Box::new(other))
53        }
54
55        /// Negate an expression: `not expr`
56        #[must_use]
57        #[allow(clippy::should_implement_trait)]
58        pub fn not(self) -> Expr {
59            !self
60        }
61    }
62
63    impl std::ops::Not for Expr {
64        type Output = Expr;
65
66        fn not(self) -> Self::Output {
67            Expr::Not(Box::new(self))
68        }
69    }
70
71    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
72    pub enum CompareOperator {
73        Eq,
74        Ne,
75        Gt,
76        Ge,
77        Lt,
78        Le,
79    }
80
81    #[derive(Clone, Debug)]
82    pub enum Value {
83        Null,
84        Bool(bool),
85        Number(BigDecimal),
86        Uuid(Uuid),
87        DateTime(DateTime<Utc>),
88        Date(NaiveDate),
89        Time(NaiveTime),
90        String(String),
91    }
92
93    impl std::fmt::Display for Value {
94        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95            match self {
96                Value::Null => write!(f, "null"),
97                Value::Bool(_) => write!(f, "bool"),
98                Value::Number(_) => write!(f, "number"),
99                Value::Uuid(_) => write!(f, "uuid"),
100                Value::DateTime(_) => write!(f, "datetime"),
101                Value::Date(_) => write!(f, "date"),
102                Value::Time(_) => write!(f, "time"),
103                Value::String(_) => write!(f, "string"),
104            }
105        }
106    }
107}
108
109// Ordering primitives
/// Direction of a single ordering key.
///
/// The serde renames below make the variants (de)serialize as the lowercase
/// strings `"asc"` and `"desc"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum SortDir {
    /// Ascending; serialized as `"asc"`.
    #[serde(rename = "asc")]
    Asc,
    /// Descending; serialized as `"desc"`.
    #[serde(rename = "desc")]
    Desc,
}
117
118impl SortDir {
119    /// Reverse the sort direction (Asc <-> Desc)
120    #[must_use]
121    pub fn reverse(self) -> Self {
122        match self {
123            SortDir::Asc => SortDir::Desc,
124            SortDir::Desc => SortDir::Asc,
125        }
126    }
127}
128
/// One entry of an ordering: a field name and the direction to sort it in.
#[derive(Clone, Debug)]
pub struct OrderKey {
    /// Name of the field to sort by.
    pub field: String,
    /// Direction applied to `field`.
    pub dir: SortDir,
}
134
/// Ordered list of sort keys; the derived `Default` is the empty ordering.
#[derive(Clone, Debug, Default)]
#[must_use]
pub struct ODataOrderBy(pub Vec<OrderKey>);
138
139impl ODataOrderBy {
140    pub fn empty() -> Self {
141        Self(vec![])
142    }
143
144    #[must_use]
145    pub fn is_empty(&self) -> bool {
146        self.0.is_empty()
147    }
148
149    /// Render as "+f1,-f2" for cursor.s
150    #[must_use]
151    pub fn to_signed_tokens(&self) -> String {
152        self.0
153            .iter()
154            .map(|k| {
155                if matches!(k.dir, SortDir::Asc) {
156                    format!("+{}", k.field)
157                } else {
158                    format!("-{}", k.field)
159                }
160            })
161            .collect::<Vec<_>>()
162            .join(",")
163    }
164
165    /// Parse signed tokens back to `ODataOrderBy` (e.g. "+a,-b" -> `ODataOrderBy`)
166    /// Returns Error for stricter validation used in cursor processing
167    ///
168    /// # Errors
169    /// Returns `Error::InvalidOrderByField` if the input is empty or contains invalid field names.
170    pub fn from_signed_tokens(signed: &str) -> Result<Self, Error> {
171        let mut out = Vec::new();
172        for seg in signed.split(',') {
173            let seg = seg.trim();
174            if seg.is_empty() {
175                continue;
176            }
177            let (dir, name) = match seg.as_bytes()[0] {
178                b'+' => (SortDir::Asc, &seg[1..]),
179                b'-' => (SortDir::Desc, &seg[1..]),
180                _ => (SortDir::Asc, seg), // default '+'
181            };
182            if name.is_empty() {
183                return Err(Error::InvalidOrderByField(seg.to_owned()));
184            }
185            out.push(OrderKey {
186                field: name.to_owned(),
187                dir,
188            });
189        }
190        if out.is_empty() {
191            return Err(Error::InvalidOrderByField("empty order".into()));
192        }
193        Ok(ODataOrderBy(out))
194    }
195
196    /// Check equality against signed token list (e.g. "+a,-b")
197    #[must_use]
198    pub fn equals_signed_tokens(&self, signed: &str) -> bool {
199        let parse = |t: &str| -> Option<(String, SortDir)> {
200            let t = t.trim();
201            if t.is_empty() {
202                return None;
203            }
204            let (dir, name) = match t.as_bytes()[0] {
205                b'+' => (SortDir::Asc, &t[1..]),
206                b'-' => (SortDir::Desc, &t[1..]),
207                _ => (SortDir::Asc, t),
208            };
209            if name.is_empty() {
210                return None;
211            }
212            Some((name.to_owned(), dir))
213        };
214        let theirs: Vec<_> = signed.split(',').filter_map(parse).collect();
215        if theirs.len() != self.0.len() {
216            return false;
217        }
218        self.0
219            .iter()
220            .zip(theirs.iter())
221            .all(|(a, (n, d))| a.field == *n && a.dir == *d)
222    }
223
224    /// Append tiebreaker if missing
225    pub fn ensure_tiebreaker(mut self, tiebreaker: &str, dir: SortDir) -> Self {
226        if !self.0.iter().any(|k| k.field == tiebreaker) {
227            self.0.push(OrderKey {
228                field: tiebreaker.to_owned(),
229                dir,
230            });
231        }
232        self
233    }
234
235    /// Reverse all sort directions (for backward pagination)
236    pub fn reverse_directions(mut self) -> Self {
237        for key in &mut self.0 {
238            key.dir = key.dir.reverse();
239        }
240        self
241    }
242}
243
244// Display trait for human-readable orderby representation
245impl std::fmt::Display for ODataOrderBy {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        if self.0.is_empty() {
248            return write!(f, "(none)");
249        }
250
251        let formatted: Vec<String> = self
252            .0
253            .iter()
254            .map(|key| {
255                let dir_str = match key.dir {
256                    SortDir::Asc => "asc",
257                    SortDir::Desc => "desc",
258                };
259                format!("{} {}", key.field, dir_str)
260            })
261            .collect();
262
263        write!(f, "{}", formatted.join(", "))
264    }
265}
266
/// Unified error type for all `OData` operations
///
/// This centralizes all OData-related errors including parsing, validation,
/// pagination, and cursor operations into a single error type using thiserror.
///
/// ## HTTP Mapping
///
/// These errors map to RFC 9457 Problem responses via the catalog in `modkit`:
/// - `InvalidFilter` → 422 `gts...~hx.odata.errors.invalid_filter.v1`
/// - `InvalidOrderByField` → 422 `gts...~hx.odata.errors.invalid_orderby.v1`
/// - Cursor errors → 422 `gts...~hx.odata.errors.invalid_cursor.v1`
#[derive(thiserror::Error, Debug, Clone)]
pub enum Error {
    // Filter parsing and validation errors
    /// A `$filter` string was rejected; the payload is the detail message
    /// (`parse_filter_string` stores the parser's error text here).
    #[error("invalid $filter: {0}")]
    InvalidFilter(String),

    // OrderBy parsing and validation errors
    /// An order-by token was rejected; the payload is the offending token or a
    /// short description (`ODataOrderBy::from_signed_tokens` uses `"empty order"`).
    #[error("unsupported $orderby field: {0}")]
    InvalidOrderByField(String),

    // Pagination and cursor errors
    /// The cursor's signed order tokens do not match the effective order
    /// (see `validate_cursor_against`).
    #[error("ORDER_MISMATCH")]
    OrderMismatch,

    /// The cursor's filter hash differs from the effective filter hash
    /// (see `validate_cursor_against`).
    #[error("FILTER_MISMATCH")]
    FilterMismatch,

    /// NOTE(review): not constructed in this file; presumably a generic cursor
    /// rejection raised by pagination code elsewhere — confirm.
    #[error("INVALID_CURSOR")]
    InvalidCursor,

    /// NOTE(review): not constructed in this file; presumably raised when a
    /// requested page limit is rejected — confirm.
    #[error("INVALID_LIMIT")]
    InvalidLimit,

    /// NOTE(review): not constructed in this file; presumably raised when an
    /// explicit order is supplied together with a cursor — confirm.
    #[error("ORDER_WITH_CURSOR")]
    OrderWithCursor,

    // Cursor parsing errors (previously CursorError variants)
    /// `CursorV1::decode`: the token is not valid base64url.
    #[error("invalid cursor: invalid base64url encoding")]
    CursorInvalidBase64,

    /// `CursorV1::decode`: the decoded bytes did not deserialize into the cursor wire format.
    #[error("invalid cursor: malformed JSON")]
    CursorInvalidJson,

    /// `CursorV1::decode`: the `v` field is not `1`.
    #[error("invalid cursor: unsupported version")]
    CursorInvalidVersion,

    /// `CursorV1::decode`: the `k` list is empty.
    #[error("invalid cursor: empty or invalid keys")]
    CursorInvalidKeys,

    /// `CursorV1::decode`: the `s` field is empty or whitespace only.
    #[error("invalid cursor: empty or invalid fields")]
    CursorInvalidFields,

    /// `CursorV1::decode`: `o` is not `"asc"`/`"desc"`, or `d` is not `"fwd"`/`"bwd"`.
    #[error("invalid cursor: invalid sort direction")]
    CursorInvalidDirection,

    // Database and low-level errors
    /// A database-level failure, carried as a message string.
    /// Not constructed in this file.
    #[error("database error: {0}")]
    Db(String),

    // Configuration errors
    /// OData parsing is not available; the payload is a static reason string.
    /// Not constructed in this file.
    #[error("OData parsing unavailable: {0}")]
    ParsingUnavailable(&'static str),
}
331
332/// Validate cursor consistency against effective order and filter hash.
333///
334/// # Errors
335/// Returns `Error::OrderMismatch` if the cursor's sort order doesn't match the effective order.
336/// Returns `Error::FilterMismatch` if the cursor's filter hash doesn't match the effective filter.
337pub fn validate_cursor_against(
338    cursor: &CursorV1,
339    effective_order: &ODataOrderBy,
340    effective_filter_hash: Option<&str>,
341) -> Result<(), Error> {
342    if !effective_order.equals_signed_tokens(&cursor.s) {
343        return Err(Error::OrderMismatch);
344    }
345    if let (Some(h), Some(cf)) = (effective_filter_hash, cursor.f.as_deref())
346        && h != cf
347    {
348        return Err(Error::FilterMismatch);
349    }
350    Ok(())
351}
352
353// Cursor v1
/// Version-1 pagination cursor, as produced by `CursorV1::decode` and consumed by
/// `CursorV1::encode`. Field names are the short keys used in the JSON wire format.
#[derive(Clone, Debug)]
pub struct CursorV1 {
    /// Key values; `decode` rejects an empty list.
    /// NOTE(review): presumably the sort-key values of the boundary row — confirm.
    pub k: Vec<String>,
    /// Sort direction, carried on the wire as `"asc"` / `"desc"`.
    /// NOTE(review): how this relates to the per-key directions in `s` is not visible here.
    pub o: SortDir,
    /// Signed order tokens (e.g. `"+a,-b"`); compared against the effective order
    /// by `validate_cursor_against`.
    pub s: String,
    /// Optional filter hash; compared against the effective filter hash by
    /// `validate_cursor_against` when both are present.
    pub f: Option<String>,
    pub d: String, // Direction: "fwd" (forward) or "bwd" (backward)
}
362
363impl CursorV1 {
364    /// Encode cursor to a base64url string.
365    ///
366    /// # Errors
367    /// Returns a JSON serialization error if encoding fails.
368    pub fn encode(&self) -> serde_json::Result<String> {
369        #[derive(serde::Serialize)]
370        struct Wire<'a> {
371            v: u8,
372            k: &'a [String],
373            o: &'a str,
374            s: &'a str,
375            #[serde(skip_serializing_if = "Option::is_none")]
376            f: &'a Option<String>,
377            d: &'a str,
378        }
379        let o = match self.o {
380            SortDir::Asc => "asc",
381            SortDir::Desc => "desc",
382        };
383        let w = Wire {
384            v: 1,
385            k: &self.k,
386            o,
387            s: &self.s,
388            f: &self.f,
389            d: &self.d,
390        };
391        serde_json::to_vec(&w).map(|x| base64_url::encode(&x))
392    }
393
394    /// Decode cursor from base64url token.
395    ///
396    /// # Errors
397    /// Returns `Error::CursorInvalidBase64` if base64 decoding fails.
398    /// Returns `Error::CursorInvalidJson` if JSON parsing fails.
399    /// Returns `Error::CursorInvalidVersion` if the version is unsupported.
400    /// Returns `Error::CursorInvalidDirection` if the direction field is invalid.
401    pub fn decode(token: &str) -> Result<Self, Error> {
402        #[derive(serde::Deserialize)]
403        struct Wire {
404            v: u8,
405            k: Vec<String>,
406            o: String,
407            s: String,
408            #[serde(default)]
409            f: Option<String>,
410            #[serde(default = "default_direction")]
411            d: String,
412        }
413
414        fn default_direction() -> String {
415            "fwd".to_owned()
416        }
417
418        let bytes = base64_url::decode(token).map_err(|_| Error::CursorInvalidBase64)?;
419        let w: Wire = serde_json::from_slice(&bytes).map_err(|_| Error::CursorInvalidJson)?;
420        if w.v != 1 {
421            return Err(Error::CursorInvalidVersion);
422        }
423        let o = match w.o.as_str() {
424            "asc" => SortDir::Asc,
425            "desc" => SortDir::Desc,
426            _ => return Err(Error::CursorInvalidDirection),
427        };
428        if w.k.is_empty() {
429            return Err(Error::CursorInvalidKeys);
430        }
431        if w.s.trim().is_empty() {
432            return Err(Error::CursorInvalidFields);
433        }
434        // Validate direction
435        if w.d != "fwd" && w.d != "bwd" {
436            return Err(Error::CursorInvalidDirection);
437        }
438        Ok(CursorV1 {
439            k: w.k,
440            o,
441            s: w.s,
442            f: w.f,
443            d: w.d,
444        })
445    }
446}
447
448// base64url helpers (no padding)
449mod base64_url {
450    use base64::Engine;
451
452    pub fn encode(bytes: &[u8]) -> String {
453        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
454    }
455
456    pub fn decode(s: &str) -> Result<Vec<u8>, base64::DecodeError> {
457        base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(s)
458    }
459}
460
// The unified ODataQuery struct as single source of truth
/// Aggregated query options. The derived `Default` has no filter, an empty
/// ordering, and every optional field set to `None`.
#[derive(Clone, Debug, Default)]
#[must_use]
pub struct ODataQuery {
    /// Parsed filter expression, if any.
    pub filter: Option<Box<ast::Expr>>,
    /// Requested ordering; may be empty.
    pub order: ODataOrderBy,
    /// Optional limit. NOTE(review): presumably the page size — confirm with callers.
    pub limit: Option<u64>,
    /// Optional pagination cursor.
    pub cursor: Option<CursorV1>,
    /// Optional filter hash. NOTE(review): presumably the value checked against
    /// `CursorV1::f` — confirm with callers.
    pub filter_hash: Option<String>,
    /// Optional list of selected field names.
    pub select: Option<Vec<String>>,
}
472
473impl ODataQuery {
474    pub fn new() -> Self {
475        Self::default()
476    }
477
478    pub fn with_filter(mut self, expr: ast::Expr) -> Self {
479        self.filter = Some(Box::new(expr));
480        self
481    }
482
483    pub fn with_order(mut self, order: ODataOrderBy) -> Self {
484        self.order = order;
485        self
486    }
487
488    pub fn with_limit(mut self, limit: u64) -> Self {
489        self.limit = Some(limit);
490        self
491    }
492
493    pub fn with_cursor(mut self, cursor: CursorV1) -> Self {
494        self.cursor = Some(cursor);
495        self
496    }
497
498    pub fn with_filter_hash(mut self, hash: String) -> Self {
499        self.filter_hash = Some(hash);
500        self
501    }
502
503    pub fn with_select(mut self, fields: Vec<String>) -> Self {
504        self.select = Some(fields);
505        self
506    }
507
508    /// Get filter as AST
509    #[must_use]
510    pub fn filter(&self) -> Option<&ast::Expr> {
511        self.filter.as_deref()
512    }
513
514    /// Check if filter is present
515    #[must_use]
516    pub fn has_filter(&self) -> bool {
517        self.filter.is_some()
518    }
519
520    /// Extract filter into AST
521    #[must_use]
522    pub fn into_filter(self) -> Option<ast::Expr> {
523        self.filter.map(|b| *b)
524    }
525
526    /// Check if field selection is present
527    #[must_use]
528    pub fn has_select(&self) -> bool {
529        self.select.is_some()
530    }
531
532    /// Get selected fields
533    #[must_use]
534    pub fn selected_fields(&self) -> Option<&[String]> {
535        self.select.as_deref()
536    }
537}
538
539impl From<Option<ast::Expr>> for ODataQuery {
540    fn from(opt: Option<ast::Expr>) -> Self {
541        match opt {
542            Some(e) => Self::default().with_filter(e),
543            None => Self::default(),
544        }
545    }
546}
547
// Test-only modules. An attribute applies only to the item that directly follows
// it, so each declaration carries its own `#[cfg(test)]`; otherwise `tests` would
// also be compiled into non-test builds.
#[cfg(test)]
mod odata_parse_tests;
#[cfg(test)]
mod tests;
551
552mod convert_odata_filters {
553    use super::ast::{CompareOperator, Expr, Value};
554    use crate::odata_filters as od;
555
556    impl From<od::CompareOperator> for CompareOperator {
557        fn from(op: od::CompareOperator) -> Self {
558            use od::CompareOperator::{
559                Equal, GreaterOrEqual, GreaterThan, LessOrEqual, LessThan, NotEqual,
560            };
561            match op {
562                Equal => CompareOperator::Eq,
563                NotEqual => CompareOperator::Ne,
564                GreaterThan => CompareOperator::Gt,
565                GreaterOrEqual => CompareOperator::Ge,
566                LessThan => CompareOperator::Lt,
567                LessOrEqual => CompareOperator::Le,
568            }
569        }
570    }
571
572    impl From<od::Value> for Value {
573        fn from(v: od::Value) -> Self {
574            match v {
575                od::Value::Null => Value::Null,
576                od::Value::Bool(b) => Value::Bool(b),
577                od::Value::Number(n) => Value::Number(n),
578                od::Value::Uuid(u) => Value::Uuid(u),
579                od::Value::DateTime(dt) => Value::DateTime(dt),
580                od::Value::Date(d) => Value::Date(d),
581                od::Value::Time(t) => Value::Time(t),
582                od::Value::String(s) => Value::String(s),
583            }
584        }
585    }
586
587    impl From<od::Expr> for Expr {
588        fn from(e: od::Expr) -> Self {
589            use od::Expr::{And, Compare, Function, Identifier, In, Not, Or, Value};
590            match e {
591                And(a, b) => Expr::And(Box::new((*a).into()), Box::new((*b).into())),
592                Or(a, b) => Expr::Or(Box::new((*a).into()), Box::new((*b).into())),
593                Not(x) => Expr::Not(Box::new((*x).into())),
594                Compare(l, op, r) => {
595                    Expr::Compare(Box::new((*l).into()), op.into(), Box::new((*r).into()))
596                }
597                In(l, list) => Expr::In(
598                    Box::new((*l).into()),
599                    list.into_iter().map(Into::into).collect(),
600                ),
601                Function(n, args) => Expr::Function(n, args.into_iter().map(Into::into).collect()),
602                Identifier(s) => Expr::Identifier(s),
603                Value(v) => Expr::Value(v.into()),
604            }
605        }
606    }
607}
608
/// Result of parsing a filter string, including both the AST and complexity metadata.
///
/// Fields are private; use `as_expr`, `into_expr` and `node_count` to read them.
#[derive(Clone, Debug)]
pub struct ParsedFilter {
    /// The converted filter expression.
    expr: ast::Expr,
    /// Number of nodes counted in the parsed tree by `parse_filter_string`.
    node_count: usize,
}
615
616impl ParsedFilter {
617    /// Get a reference to the parsed expression
618    #[must_use]
619    pub fn as_expr(&self) -> &ast::Expr {
620        &self.expr
621    }
622
623    /// Consume and extract the parsed expression
624    #[must_use]
625    pub fn into_expr(self) -> ast::Expr {
626        self.expr
627    }
628
629    /// Get the AST node count for budget enforcement
630    #[must_use]
631    pub fn node_count(&self) -> usize {
632        self.node_count
633    }
634}
635
636/// Parse a raw $filter string into internal AST with complexity metadata.
637///
638/// This function encapsulates the parsing logic and node counting.
639///
640/// # Errors
641/// - `Error::InvalidFilter` if the filter string is malformed or parsing fails
642///
643/// # Example
644/// ```ignore
645/// let result = parse_filter_string("name eq 'John' and age gt 18")?;
646/// if result.node_count() > MAX_NODES {
647///     return Err(Error::InvalidFilter("too complex".into()));
648/// }
649/// ```
650pub fn parse_filter_string(raw: &str) -> Result<ParsedFilter, Error> {
651    use crate::odata_filters as od;
652
653    /// Count nodes in AST for complexity budget enforcement.
654    fn count_ast_nodes(e: &od::Expr) -> usize {
655        use od::Expr::{And, Compare, Function, Identifier, In, Not, Or, Value};
656        match e {
657            Value(_) | Identifier(_) => 1,
658            Not(x) => 1 + count_ast_nodes(x),
659            And(a, b) | Or(a, b) | Compare(a, _, b) => 1 + count_ast_nodes(a) + count_ast_nodes(b),
660            In(a, list) => 1 + count_ast_nodes(a) + list.iter().map(count_ast_nodes).sum::<usize>(),
661            Function(_, args) => 1 + args.iter().map(count_ast_nodes).sum::<usize>(),
662        }
663    }
664
665    let ast_src = od::parse_str(raw).map_err(|e| Error::InvalidFilter(format!("{e}")))?;
666
667    let node_count = count_ast_nodes(&ast_src);
668    let expr: ast::Expr = ast_src.into();
669
670    Ok(ParsedFilter { expr, node_count })
671}