Skip to main content

modkit_odata/
lib.rs

1#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
2pub mod builder;
3pub mod errors;
4pub mod filter;
5pub mod limits;
6pub mod page;
7pub mod pagination;
8pub mod problem_mapping;
9pub mod schema;
10
11pub use builder::QueryBuilder;
12pub use limits::ODataLimits;
13pub use page::{Page, PageInfo};
14pub use pagination::{normalize_filter_for_hash, short_filter_hash};
15pub use schema::{FieldRef, Schema};
16
17pub mod ast {
18    use bigdecimal::BigDecimal;
19    use chrono::{DateTime, NaiveDate, NaiveTime, Utc};
20    use uuid::Uuid;
21
    /// Filter expression tree.
    ///
    /// Nodes can be built directly from the variants below or combined with the
    /// `and` / `or` / `not` helpers implemented on this type.
    #[derive(Clone, Debug)]
    pub enum Expr {
        /// Logical conjunction of two sub-expressions (`a and b`).
        And(Box<Expr>, Box<Expr>),
        /// Logical disjunction of two sub-expressions (`a or b`).
        Or(Box<Expr>, Box<Expr>),
        /// Logical negation of a sub-expression (`not a`).
        Not(Box<Expr>),
        /// Comparison of a left and a right operand using a [`CompareOperator`].
        Compare(Box<Expr>, CompareOperator, Box<Expr>),
        /// An operand paired with a list of candidate expressions.
        In(Box<Expr>, Vec<Expr>),
        /// A function name together with its argument expressions.
        Function(String, Vec<Expr>),
        /// A named reference (presumably a field/property name; confirm against the parser).
        Identifier(String),
        /// A literal [`Value`].
        Value(Value),
    }
33
34    impl Expr {
35        /// Combine two expressions with AND: `expr1 and expr2`
36        ///
37        /// # Example
38        ///
39        /// ```rust,ignore
40        /// let filter = ID.eq(user_id).and(NAME.contains("john"));
41        /// ```
42        #[must_use]
43        pub fn and(self, other: Expr) -> Expr {
44            Expr::And(Box::new(self), Box::new(other))
45        }
46
47        /// Combine two expressions with OR: `expr1 or expr2`
48        #[must_use]
49        pub fn or(self, other: Expr) -> Expr {
50            Expr::Or(Box::new(self), Box::new(other))
51        }
52
53        /// Negate an expression: `not expr`
54        #[must_use]
55        #[allow(clippy::should_implement_trait)]
56        pub fn not(self) -> Expr {
57            !self
58        }
59    }
60
61    impl std::ops::Not for Expr {
62        type Output = Expr;
63
64        fn not(self) -> Self::Output {
65            Expr::Not(Box::new(self))
66        }
67    }
68
    /// Binary comparison operator used by [`Expr::Compare`].
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    pub enum CompareOperator {
        /// Equal.
        Eq,
        /// Not equal.
        Ne,
        /// Greater than.
        Gt,
        /// Greater than or equal.
        Ge,
        /// Less than.
        Lt,
        /// Less than or equal.
        Le,
    }
78
    /// Literal value carried by [`Expr::Value`].
    ///
    /// Note that the `Display` implementation for this type prints the *kind*
    /// of the value (e.g. `"string"`), not the payload itself.
    #[derive(Clone, Debug)]
    pub enum Value {
        /// The `null` literal.
        Null,
        /// Boolean literal.
        Bool(bool),
        /// Numeric literal stored as an arbitrary-precision decimal.
        Number(BigDecimal),
        /// UUID literal.
        Uuid(Uuid),
        /// Date-time literal in UTC.
        DateTime(DateTime<Utc>),
        /// Calendar date literal without a time zone.
        Date(NaiveDate),
        /// Time-of-day literal without a time zone.
        Time(NaiveTime),
        /// String literal.
        String(String),
    }
90
91    impl std::fmt::Display for Value {
92        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93            match self {
94                Value::Null => write!(f, "null"),
95                Value::Bool(_) => write!(f, "bool"),
96                Value::Number(_) => write!(f, "number"),
97                Value::Uuid(_) => write!(f, "uuid"),
98                Value::DateTime(_) => write!(f, "datetime"),
99                Value::Date(_) => write!(f, "date"),
100                Value::Time(_) => write!(f, "time"),
101                Value::String(_) => write!(f, "string"),
102            }
103        }
104    }
105}
106
107// Ordering primitives
/// Sort direction of a single order key.
///
/// Serialized through serde as the lowercase strings `"asc"` and `"desc"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum SortDir {
    /// Ascending order (wire form: `"asc"`).
    #[serde(rename = "asc")]
    Asc,
    /// Descending order (wire form: `"desc"`).
    #[serde(rename = "desc")]
    Desc,
}
115
116impl SortDir {
117    /// Reverse the sort direction (Asc <-> Desc)
118    #[must_use]
119    pub fn reverse(self) -> Self {
120        match self {
121            SortDir::Asc => SortDir::Desc,
122            SortDir::Desc => SortDir::Asc,
123        }
124    }
125}
126
/// One entry of an ordering: a field name plus the direction to sort it in.
#[derive(Clone, Debug)]
pub struct OrderKey {
    /// Name of the field to sort by.
    pub field: String,
    /// Direction applied to `field`.
    pub dir: SortDir,
}
132
/// Ordered list of sort keys; the default value has no keys.
///
/// The inner vector is public, so callers may construct or inspect it directly.
#[derive(Clone, Debug, Default)]
#[must_use]
pub struct ODataOrderBy(pub Vec<OrderKey>);
136
137impl ODataOrderBy {
138    pub fn empty() -> Self {
139        Self(vec![])
140    }
141
142    #[must_use]
143    pub fn is_empty(&self) -> bool {
144        self.0.is_empty()
145    }
146
147    /// Render as "+f1,-f2" for cursor.s
148    #[must_use]
149    pub fn to_signed_tokens(&self) -> String {
150        self.0
151            .iter()
152            .map(|k| {
153                if matches!(k.dir, SortDir::Asc) {
154                    format!("+{}", k.field)
155                } else {
156                    format!("-{}", k.field)
157                }
158            })
159            .collect::<Vec<_>>()
160            .join(",")
161    }
162
163    /// Parse signed tokens back to `ODataOrderBy` (e.g. "+a,-b" -> `ODataOrderBy`)
164    /// Returns Error for stricter validation used in cursor processing
165    ///
166    /// # Errors
167    /// Returns `Error::InvalidOrderByField` if the input is empty or contains invalid field names.
168    pub fn from_signed_tokens(signed: &str) -> Result<Self, Error> {
169        let mut out = Vec::new();
170        for seg in signed.split(',') {
171            let seg = seg.trim();
172            if seg.is_empty() {
173                continue;
174            }
175            let (dir, name) = match seg.as_bytes()[0] {
176                b'+' => (SortDir::Asc, &seg[1..]),
177                b'-' => (SortDir::Desc, &seg[1..]),
178                _ => (SortDir::Asc, seg), // default '+'
179            };
180            if name.is_empty() {
181                return Err(Error::InvalidOrderByField(seg.to_owned()));
182            }
183            out.push(OrderKey {
184                field: name.to_owned(),
185                dir,
186            });
187        }
188        if out.is_empty() {
189            return Err(Error::InvalidOrderByField("empty order".into()));
190        }
191        Ok(ODataOrderBy(out))
192    }
193
194    /// Check equality against signed token list (e.g. "+a,-b")
195    #[must_use]
196    pub fn equals_signed_tokens(&self, signed: &str) -> bool {
197        let parse = |t: &str| -> Option<(String, SortDir)> {
198            let t = t.trim();
199            if t.is_empty() {
200                return None;
201            }
202            let (dir, name) = match t.as_bytes()[0] {
203                b'+' => (SortDir::Asc, &t[1..]),
204                b'-' => (SortDir::Desc, &t[1..]),
205                _ => (SortDir::Asc, t),
206            };
207            if name.is_empty() {
208                return None;
209            }
210            Some((name.to_owned(), dir))
211        };
212        let theirs: Vec<_> = signed.split(',').filter_map(parse).collect();
213        if theirs.len() != self.0.len() {
214            return false;
215        }
216        self.0
217            .iter()
218            .zip(theirs.iter())
219            .all(|(a, (n, d))| a.field == *n && a.dir == *d)
220    }
221
222    /// Append tiebreaker if missing
223    pub fn ensure_tiebreaker(mut self, tiebreaker: &str, dir: SortDir) -> Self {
224        if !self.0.iter().any(|k| k.field == tiebreaker) {
225            self.0.push(OrderKey {
226                field: tiebreaker.to_owned(),
227                dir,
228            });
229        }
230        self
231    }
232
233    /// Reverse all sort directions (for backward pagination)
234    pub fn reverse_directions(mut self) -> Self {
235        for key in &mut self.0 {
236            key.dir = key.dir.reverse();
237        }
238        self
239    }
240}
241
242// Display trait for human-readable orderby representation
243impl std::fmt::Display for ODataOrderBy {
244    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245        if self.0.is_empty() {
246            return write!(f, "(none)");
247        }
248
249        let formatted: Vec<String> = self
250            .0
251            .iter()
252            .map(|key| {
253                let dir_str = match key.dir {
254                    SortDir::Asc => "asc",
255                    SortDir::Desc => "desc",
256                };
257                format!("{} {}", key.field, dir_str)
258            })
259            .collect();
260
261        write!(f, "{}", formatted.join(", "))
262    }
263}
264
/// Unified error type for all `OData` operations
///
/// This centralizes all OData-related errors including parsing, validation,
/// pagination, and cursor operations into a single error type using thiserror.
///
/// ## HTTP Mapping
///
/// These errors map to RFC 9457 Problem responses via the catalog in `modkit`:
/// - `InvalidFilter` → 422 `gts...~hx.odata.errors.invalid_filter.v1`
/// - `InvalidOrderByField` → 422 `gts...~hx.odata.errors.invalid_orderby.v1`
/// - Cursor errors → 422 `gts...~hx.odata.errors.invalid_cursor.v1`
#[derive(thiserror::Error, Debug, Clone)]
pub enum Error {
    // Filter parsing and validation errors
    /// The `$filter` expression could not be parsed or validated; the payload
    /// carries the detail message.
    #[error("invalid $filter: {0}")]
    InvalidFilter(String),

    // OrderBy parsing and validation errors
    /// An order-by token was rejected; the payload is the offending token, or
    /// `"empty order"` when `ODataOrderBy::from_signed_tokens` found no tokens.
    #[error("unsupported $orderby field: {0}")]
    InvalidOrderByField(String),

    // Pagination and cursor errors
    /// The cursor's signed order tokens do not match the effective order
    /// (returned by `validate_cursor_against`).
    #[error("ORDER_MISMATCH")]
    OrderMismatch,

    /// The cursor's filter hash differs from the effective filter hash
    /// (returned by `validate_cursor_against`).
    #[error("FILTER_MISMATCH")]
    FilterMismatch,

    /// Generic invalid-cursor error (not produced in this file; see callers).
    #[error("INVALID_CURSOR")]
    InvalidCursor,

    /// Invalid limit value (not produced in this file; see callers).
    #[error("INVALID_LIMIT")]
    InvalidLimit,

    /// Presumably raised when an explicit order is combined with a cursor
    /// (not produced in this file; confirm against callers).
    #[error("ORDER_WITH_CURSOR")]
    OrderWithCursor,

    // Cursor parsing errors (previously CursorError variants)
    /// The cursor token is not valid unpadded base64url.
    #[error("invalid cursor: invalid base64url encoding")]
    CursorInvalidBase64,

    /// The decoded cursor bytes are not the expected JSON document.
    #[error("invalid cursor: malformed JSON")]
    CursorInvalidJson,

    /// The cursor's `v` field is not a supported version (only `1` is accepted).
    #[error("invalid cursor: unsupported version")]
    CursorInvalidVersion,

    /// The cursor's key list `k` is empty.
    #[error("invalid cursor: empty or invalid keys")]
    CursorInvalidKeys,

    /// The cursor's signed order string `s` is blank.
    #[error("invalid cursor: empty or invalid fields")]
    CursorInvalidFields,

    /// The cursor's `o` field is not "asc"/"desc", or its `d` field is not
    /// "fwd"/"bwd".
    #[error("invalid cursor: invalid sort direction")]
    CursorInvalidDirection,

    // Database and low-level errors
    /// A database-level failure, carried as a message string.
    #[error("database error: {0}")]
    Db(String),

    // Configuration errors
    /// `OData` parsing support is not compiled into this build (see the
    /// `with-odata-params` feature); the payload explains what is missing.
    #[error("OData parsing unavailable: {0}")]
    ParsingUnavailable(&'static str),
}
329
330/// Validate cursor consistency against effective order and filter hash.
331///
332/// # Errors
333/// Returns `Error::OrderMismatch` if the cursor's sort order doesn't match the effective order.
334/// Returns `Error::FilterMismatch` if the cursor's filter hash doesn't match the effective filter.
335pub fn validate_cursor_against(
336    cursor: &CursorV1,
337    effective_order: &ODataOrderBy,
338    effective_filter_hash: Option<&str>,
339) -> Result<(), Error> {
340    if !effective_order.equals_signed_tokens(&cursor.s) {
341        return Err(Error::OrderMismatch);
342    }
343    if let (Some(h), Some(cf)) = (effective_filter_hash, cursor.f.as_deref())
344        && h != cf
345    {
346        return Err(Error::FilterMismatch);
347    }
348    Ok(())
349}
350
351// Cursor v1
352#[derive(Clone, Debug)]
353pub struct CursorV1 {
354    pub k: Vec<String>,
355    pub o: SortDir,
356    pub s: String,
357    pub f: Option<String>,
358    pub d: String, // Direction: "fwd" (forward) or "bwd" (backward)
359}
360
361impl CursorV1 {
362    /// Encode cursor to a base64url string.
363    ///
364    /// # Errors
365    /// Returns a JSON serialization error if encoding fails.
366    pub fn encode(&self) -> serde_json::Result<String> {
367        #[derive(serde::Serialize)]
368        struct Wire<'a> {
369            v: u8,
370            k: &'a [String],
371            o: &'a str,
372            s: &'a str,
373            #[serde(skip_serializing_if = "Option::is_none")]
374            f: &'a Option<String>,
375            d: &'a str,
376        }
377        let o = match self.o {
378            SortDir::Asc => "asc",
379            SortDir::Desc => "desc",
380        };
381        let w = Wire {
382            v: 1,
383            k: &self.k,
384            o,
385            s: &self.s,
386            f: &self.f,
387            d: &self.d,
388        };
389        serde_json::to_vec(&w).map(|x| base64_url::encode(&x))
390    }
391
392    /// Decode cursor from base64url token.
393    ///
394    /// # Errors
395    /// Returns `Error::CursorInvalidBase64` if base64 decoding fails.
396    /// Returns `Error::CursorInvalidJson` if JSON parsing fails.
397    /// Returns `Error::CursorInvalidVersion` if the version is unsupported.
398    /// Returns `Error::CursorInvalidDirection` if the direction field is invalid.
399    pub fn decode(token: &str) -> Result<Self, Error> {
400        #[derive(serde::Deserialize)]
401        struct Wire {
402            v: u8,
403            k: Vec<String>,
404            o: String,
405            s: String,
406            #[serde(default)]
407            f: Option<String>,
408            #[serde(default = "default_direction")]
409            d: String,
410        }
411
412        fn default_direction() -> String {
413            "fwd".to_owned()
414        }
415
416        let bytes = base64_url::decode(token).map_err(|_| Error::CursorInvalidBase64)?;
417        let w: Wire = serde_json::from_slice(&bytes).map_err(|_| Error::CursorInvalidJson)?;
418        if w.v != 1 {
419            return Err(Error::CursorInvalidVersion);
420        }
421        let o = match w.o.as_str() {
422            "asc" => SortDir::Asc,
423            "desc" => SortDir::Desc,
424            _ => return Err(Error::CursorInvalidDirection),
425        };
426        if w.k.is_empty() {
427            return Err(Error::CursorInvalidKeys);
428        }
429        if w.s.trim().is_empty() {
430            return Err(Error::CursorInvalidFields);
431        }
432        // Validate direction
433        if w.d != "fwd" && w.d != "bwd" {
434            return Err(Error::CursorInvalidDirection);
435        }
436        Ok(CursorV1 {
437            k: w.k,
438            o,
439            s: w.s,
440            f: w.f,
441            d: w.d,
442        })
443    }
444}
445
446// base64url helpers (no padding)
447mod base64_url {
448    use base64::Engine;
449
450    pub fn encode(bytes: &[u8]) -> String {
451        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
452    }
453
454    pub fn decode(s: &str) -> Result<Vec<u8>, base64::DecodeError> {
455        base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(s)
456    }
457}
458
459// The unified ODataQuery struct as single source of truth
/// Parsed `OData` query: every field is optional or empty by default and can
/// be set through the `with_*` builder methods.
#[derive(Clone, Debug, Default)]
#[must_use]
pub struct ODataQuery {
    /// Filter expression, if any.
    pub filter: Option<Box<ast::Expr>>,
    /// Requested ordering; empty when none was given.
    pub order: ODataOrderBy,
    /// Requested limit, if any (presumably a page size; confirm against callers).
    pub limit: Option<u64>,
    /// Decoded pagination cursor, if any.
    pub cursor: Option<CursorV1>,
    /// Hash of the filter, if any (presumably compared with the cursor's `f`
    /// field; confirm against callers).
    pub filter_hash: Option<String>,
    /// Selected field names, if a field selection was given.
    pub select: Option<Vec<String>>,
}
470
471impl ODataQuery {
472    pub fn new() -> Self {
473        Self::default()
474    }
475
476    pub fn with_filter(mut self, expr: ast::Expr) -> Self {
477        self.filter = Some(Box::new(expr));
478        self
479    }
480
481    pub fn with_order(mut self, order: ODataOrderBy) -> Self {
482        self.order = order;
483        self
484    }
485
486    pub fn with_limit(mut self, limit: u64) -> Self {
487        self.limit = Some(limit);
488        self
489    }
490
491    pub fn with_cursor(mut self, cursor: CursorV1) -> Self {
492        self.cursor = Some(cursor);
493        self
494    }
495
496    pub fn with_filter_hash(mut self, hash: String) -> Self {
497        self.filter_hash = Some(hash);
498        self
499    }
500
501    pub fn with_select(mut self, fields: Vec<String>) -> Self {
502        self.select = Some(fields);
503        self
504    }
505
506    /// Get filter as AST
507    #[must_use]
508    pub fn filter(&self) -> Option<&ast::Expr> {
509        self.filter.as_deref()
510    }
511
512    /// Check if filter is present
513    #[must_use]
514    pub fn has_filter(&self) -> bool {
515        self.filter.is_some()
516    }
517
518    /// Extract filter into AST
519    #[must_use]
520    pub fn into_filter(self) -> Option<ast::Expr> {
521        self.filter.map(|b| *b)
522    }
523
524    /// Check if field selection is present
525    #[must_use]
526    pub fn has_select(&self) -> bool {
527        self.select.is_some()
528    }
529
530    /// Get selected fields
531    #[must_use]
532    pub fn selected_fields(&self) -> Option<&[String]> {
533        self.select.as_deref()
534    }
535}
536
537impl From<Option<ast::Expr>> for ODataQuery {
538    fn from(opt: Option<ast::Expr>) -> Self {
539        match opt {
540            Some(e) => Self::default().with_filter(e),
541            None => Self::default(),
542        }
543    }
544}
545
546mod tests;
547
// Conversions from the `odata_params` filter AST into this crate's own AST.
#[cfg(feature = "with-odata-params")]
mod convert_odata_params {
    use super::ast::{CompareOperator, Expr, Value};
    use odata_params::filters as od;

    /// Maps each parser comparison operator onto its local counterpart.
    impl From<od::CompareOperator> for CompareOperator {
        fn from(op: od::CompareOperator) -> Self {
            match op {
                od::CompareOperator::Equal => CompareOperator::Eq,
                od::CompareOperator::NotEqual => CompareOperator::Ne,
                od::CompareOperator::GreaterThan => CompareOperator::Gt,
                od::CompareOperator::GreaterOrEqual => CompareOperator::Ge,
                od::CompareOperator::LessThan => CompareOperator::Lt,
                od::CompareOperator::LessOrEqual => CompareOperator::Le,
            }
        }
    }

    /// Moves each parser literal into the local variant of the same name.
    impl From<od::Value> for Value {
        fn from(v: od::Value) -> Self {
            match v {
                od::Value::Null => Value::Null,
                od::Value::Bool(flag) => Value::Bool(flag),
                od::Value::Number(number) => Value::Number(number),
                od::Value::Uuid(uuid) => Value::Uuid(uuid),
                od::Value::DateTime(stamp) => Value::DateTime(stamp),
                od::Value::Date(date) => Value::Date(date),
                od::Value::Time(time) => Value::Time(time),
                od::Value::String(text) => Value::String(text),
            }
        }
    }

    /// Recursively converts a parser expression tree, node for node.
    impl From<od::Expr> for Expr {
        fn from(e: od::Expr) -> Self {
            match e {
                od::Expr::And(lhs, rhs) => {
                    Expr::And(Box::new(Expr::from(*lhs)), Box::new(Expr::from(*rhs)))
                }
                od::Expr::Or(lhs, rhs) => {
                    Expr::Or(Box::new(Expr::from(*lhs)), Box::new(Expr::from(*rhs)))
                }
                od::Expr::Not(inner) => Expr::Not(Box::new(Expr::from(*inner))),
                od::Expr::Compare(lhs, op, rhs) => Expr::Compare(
                    Box::new(Expr::from(*lhs)),
                    CompareOperator::from(op),
                    Box::new(Expr::from(*rhs)),
                ),
                od::Expr::In(subject, candidates) => Expr::In(
                    Box::new(Expr::from(*subject)),
                    candidates.into_iter().map(Expr::from).collect(),
                ),
                od::Expr::Function(name, args) => {
                    Expr::Function(name, args.into_iter().map(Expr::from).collect())
                }
                od::Expr::Identifier(name) => Expr::Identifier(name),
                od::Expr::Value(literal) => Expr::Value(Value::from(literal)),
            }
        }
    }
}
605
606/// Result of parsing a filter string, including both the AST and complexity metadata.
607#[derive(Clone, Debug)]
608pub struct ParsedFilter {
609    expr: ast::Expr,
610    node_count: usize,
611}
612
613impl ParsedFilter {
614    /// Get a reference to the parsed expression
615    #[must_use]
616    pub fn as_expr(&self) -> &ast::Expr {
617        &self.expr
618    }
619
620    /// Consume and extract the parsed expression
621    #[must_use]
622    pub fn into_expr(self) -> ast::Expr {
623        self.expr
624    }
625
626    /// Get the AST node count for budget enforcement
627    #[must_use]
628    pub fn node_count(&self) -> usize {
629        self.node_count
630    }
631}
632
/// Parse a raw `$filter` string into the internal AST and record its size.
///
/// Parsing is delegated to `odata_params`; the node count is taken on the
/// parser's tree before it is converted, so callers can enforce a complexity
/// budget without depending on that crate themselves.
///
/// # Errors
/// - `Error::InvalidFilter` if the filter string is malformed or parsing fails
/// - `Error::ParsingUnavailable` if the `with-odata-params` feature is disabled
///
/// # Example
/// ```ignore
/// let result = parse_filter_string("name eq 'John' and age gt 18")?;
/// if result.node_count() > MAX_NODES {
///     return Err(Error::InvalidFilter("too complex".into()));
/// }
/// ```
#[cfg(feature = "with-odata-params")]
pub fn parse_filter_string(raw: &str) -> Result<ParsedFilter, Error> {
    use odata_params::filters as od;

    /// Combined node count of a list of sibling expressions.
    fn count_all(nodes: &[od::Expr]) -> usize {
        nodes.iter().map(count_ast_nodes).sum()
    }

    /// Count nodes in `odata_params` AST for complexity budget enforcement.
    fn count_ast_nodes(e: &od::Expr) -> usize {
        let descendants = match e {
            od::Expr::Value(_) | od::Expr::Identifier(_) => 0,
            od::Expr::Not(inner) => count_ast_nodes(inner),
            od::Expr::And(lhs, rhs) | od::Expr::Or(lhs, rhs) | od::Expr::Compare(lhs, _, rhs) => {
                count_ast_nodes(lhs) + count_ast_nodes(rhs)
            }
            od::Expr::In(subject, candidates) => count_ast_nodes(subject) + count_all(candidates),
            od::Expr::Function(_, args) => count_all(args),
        };
        // Every node contributes itself plus everything below it.
        1 + descendants
    }

    let parsed = match od::parse_str(raw) {
        Ok(tree) => tree,
        Err(err) => return Err(Error::InvalidFilter(format!("{err:?}"))),
    };

    let node_count = count_ast_nodes(&parsed);
    Ok(ParsedFilter {
        expr: ast::Expr::from(parsed),
        node_count,
    })
}
672
673/// Parse `OData` filter string.
674///
675/// This stub is compiled when the `with-odata-params` feature is disabled.
676///
677/// # Errors
678///
679/// Always returns `Error::ParsingUnavailable` because `OData` parsing
680/// support is not enabled in this build.
681#[cfg(not(feature = "with-odata-params"))]
682pub fn parse_filter_string(_raw: &str) -> Result<ParsedFilter, Error> {
683    Err(Error::ParsingUnavailable(
684        "OData filter parsing requires 'with-odata-params' feature",
685    ))
686}