Skip to main content

modkit_db/odata/
core.rs

1//! `OData` (filters) → `sea_orm::Condition` compiler (AST in, SQL out).
2//! Parsing belongs to API/gateway. This module only consumes `modkit_odata::ast::Expr`.
3
4use std::collections::HashMap;
5
6use bigdecimal::{BigDecimal, ToPrimitive};
7use chrono::{NaiveDate, NaiveTime, Utc};
8use modkit_odata::{CursorV1, Error as ODataError, ODataOrderBy, ODataQuery, SortDir, ast as core};
9use rust_decimal::Decimal;
10use sea_orm::{
11    ColumnTrait, Condition, ConnectionTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect,
12    sea_query::{Expr, Order},
13};
14use thiserror::Error;
15
16use modkit_odata::filter::FieldKind;
17
18use crate::odata::LimitCfg;
19
20/// Type alias for cursor extraction function to reduce type complexity
21type CursorExtractor<E> = fn(&<E as EntityTrait>::Model) -> String;
22
23#[derive(Clone)]
24pub struct Field<E: EntityTrait> {
25    pub col: E::Column,
26    pub kind: FieldKind,
27    pub to_string_for_cursor: Option<CursorExtractor<E>>,
28}
29
30#[derive(Clone)]
31#[must_use]
32pub struct FieldMap<E: EntityTrait> {
33    map: HashMap<String, Field<E>>,
34}
35
36impl<E: EntityTrait> Default for FieldMap<E> {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl<E: EntityTrait> FieldMap<E> {
43    pub fn new() -> Self {
44        Self {
45            map: HashMap::new(),
46        }
47    }
48    pub fn insert(mut self, api_name: impl Into<String>, col: E::Column, kind: FieldKind) -> Self {
49        self.map.insert(
50            api_name.into().to_lowercase(),
51            Field {
52                col,
53                kind,
54                to_string_for_cursor: None,
55            },
56        );
57        self
58    }
59
60    pub fn insert_with_extractor(
61        mut self,
62        api_name: impl Into<String>,
63        col: E::Column,
64        kind: FieldKind,
65        to_string_for_cursor: CursorExtractor<E>,
66    ) -> Self {
67        self.map.insert(
68            api_name.into().to_lowercase(),
69            Field {
70                col,
71                kind,
72                to_string_for_cursor: Some(to_string_for_cursor),
73            },
74        );
75        self
76    }
77
78    pub fn encode_model_key(&self, model: &E::Model, field_name: &str) -> Option<String> {
79        let f = self.get(field_name)?;
80        f.to_string_for_cursor.map(|f| f(model))
81    }
82    #[must_use]
83    pub fn get(&self, name: &str) -> Option<&Field<E>> {
84        self.map.get(&name.to_lowercase())
85    }
86}
87
88#[derive(Debug, Error, Clone)]
89pub enum ODataBuildError {
90    #[error("unknown field: {0}")]
91    UnknownField(String),
92
93    #[error("type mismatch: expected {expected:?}, got {got}")]
94    TypeMismatch {
95        expected: FieldKind,
96        got: &'static str,
97    },
98
99    #[error("unsupported operator: {0:?}")]
100    UnsupportedOp(core::CompareOperator),
101
102    #[error("unsupported function or args: {0}()")]
103    UnsupportedFn(String),
104
105    #[error("IN() list supports only literals")]
106    NonLiteralInList,
107
108    #[error("bare identifier not allowed: {0}")]
109    BareIdentifier(String),
110
111    #[error("bare literal not allowed")]
112    BareLiteral,
113
114    #[error("{0}")]
115    Other(&'static str),
116}
117pub type ODataBuildResult<T> = Result<T, ODataBuildError>;
118
119/* ---------- coercion helpers ---------- */
120
121fn bigdecimal_to_decimal(bd: &BigDecimal) -> ODataBuildResult<Decimal> {
122    // Robust conversion: preserve precision via string.
123    let s = bd.normalized().to_string();
124    Decimal::from_str_exact(&s)
125        .or_else(|_| s.parse::<Decimal>())
126        .map_err(|_| ODataBuildError::Other("invalid decimal"))
127}
128
129fn coerce(kind: FieldKind, v: &core::Value) -> ODataBuildResult<sea_orm::Value> {
130    use core::Value as V;
131    Ok(match (kind, v) {
132        (FieldKind::String, V::String(s)) => sea_orm::Value::String(Some(Box::new(s.clone()))),
133
134        (FieldKind::I64, V::Number(n)) => {
135            let i = n.to_i64().ok_or(ODataBuildError::TypeMismatch {
136                expected: FieldKind::I64,
137                got: "number",
138            })?;
139            sea_orm::Value::BigInt(Some(i))
140        }
141
142        (FieldKind::F64, V::Number(n)) => {
143            let f = n.to_f64().ok_or(ODataBuildError::TypeMismatch {
144                expected: FieldKind::F64,
145                got: "number",
146            })?;
147            sea_orm::Value::Double(Some(f))
148        }
149
150        // Box the Decimal
151        (FieldKind::Decimal, V::Number(n)) => {
152            sea_orm::Value::Decimal(Some(Box::new(bigdecimal_to_decimal(n)?)))
153        }
154
155        (FieldKind::Bool, V::Bool(b)) => sea_orm::Value::Bool(Some(*b)),
156
157        // Box the Uuid
158        (FieldKind::Uuid, V::Uuid(u)) => sea_orm::Value::Uuid(Some(Box::new(*u))),
159
160        // Box chrono types
161        (FieldKind::DateTimeUtc, V::DateTime(dt)) => {
162            sea_orm::Value::ChronoDateTimeUtc(Some(Box::new(*dt)))
163        }
164        (FieldKind::Date, V::Date(d)) => sea_orm::Value::ChronoDate(Some(Box::new(*d))),
165        (FieldKind::Time, V::Time(t)) => sea_orm::Value::ChronoTime(Some(Box::new(*t))),
166
167        (expected, V::Null) => {
168            return Err(ODataBuildError::TypeMismatch {
169                expected,
170                got: "null",
171            });
172        }
173        (expected, V::String(_)) => {
174            return Err(ODataBuildError::TypeMismatch {
175                expected,
176                got: "string",
177            });
178        }
179        (expected, V::Number(_)) => {
180            return Err(ODataBuildError::TypeMismatch {
181                expected,
182                got: "number",
183            });
184        }
185        (expected, V::Bool(_)) => {
186            return Err(ODataBuildError::TypeMismatch {
187                expected,
188                got: "bool",
189            });
190        }
191        (expected, V::Uuid(_)) => {
192            return Err(ODataBuildError::TypeMismatch {
193                expected,
194                got: "uuid",
195            });
196        }
197        (expected, V::DateTime(_)) => {
198            return Err(ODataBuildError::TypeMismatch {
199                expected,
200                got: "datetime",
201            });
202        }
203        (expected, V::Date(_)) => {
204            return Err(ODataBuildError::TypeMismatch {
205                expected,
206                got: "date",
207            });
208        }
209        (expected, V::Time(_)) => {
210            return Err(ODataBuildError::TypeMismatch {
211                expected,
212                got: "time",
213            });
214        }
215    })
216}
217
218fn coerce_many(kind: FieldKind, items: &[core::Expr]) -> ODataBuildResult<Vec<sea_orm::Value>> {
219    items
220        .iter()
221        .map(|e| match e {
222            core::Expr::Value(v) => coerce(kind, v),
223            _ => Err(ODataBuildError::NonLiteralInList),
224        })
225        .collect()
226}
227
228/* ---------- LIKE helpers ---------- */
229
/// Escapes the LIKE metacharacters `%` and `_`, and the escape character `\`
/// itself, by prefixing each with a backslash so `s` matches literally.
///
/// NOTE(review): the generated patterns are used without an explicit `ESCAPE`
/// clause and so rely on `\` being the backend's default LIKE escape. That holds
/// for PostgreSQL and MySQL, but SQLite has no default escape character. Confirm
/// which backends this runs against.
fn like_escape(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut acc, c| {
        if matches!(c, '%' | '_' | '\\') {
            acc.push('\\');
        }
        acc.push(c);
        acc
    })
}
/// LIKE pattern matching values that contain `s` anywhere.
fn like_contains(s: &str) -> String {
    let escaped = like_escape(s);
    format!("%{escaped}%")
}
/// LIKE pattern matching values that start with `s`.
fn like_starts(s: &str) -> String {
    let escaped = like_escape(s);
    format!("{escaped}%")
}
/// LIKE pattern matching values that end with `s`.
fn like_ends(s: &str) -> String {
    let escaped = like_escape(s);
    format!("%{escaped}")
}
252
253/* ---------- small guards ---------- */
254
255#[inline]
256fn ensure_string_field<E: EntityTrait>(f: &Field<E>, _field_name: &str) -> ODataBuildResult<()> {
257    if f.kind != FieldKind::String {
258        return Err(ODataBuildError::TypeMismatch {
259            expected: FieldKind::String,
260            got: "non-string field",
261        });
262    }
263    Ok(())
264}
265
266/* ---------- cursor value encoding/decoding ---------- */
267
268/// Parse a cursor value from string based on field kind
269pub fn parse_cursor_value(kind: FieldKind, s: &str) -> ODataBuildResult<sea_orm::Value> {
270    use sea_orm::Value as V;
271
272    let result = match kind {
273        FieldKind::String => V::String(Some(Box::new(s.to_owned()))),
274        FieldKind::I64 => {
275            let i = s
276                .parse::<i64>()
277                .map_err(|_| ODataBuildError::Other("invalid i64 in cursor"))?;
278            V::BigInt(Some(i))
279        }
280        FieldKind::F64 => {
281            let f = s
282                .parse::<f64>()
283                .map_err(|_| ODataBuildError::Other("invalid f64 in cursor"))?;
284            V::Double(Some(f))
285        }
286        FieldKind::Bool => {
287            let b = s
288                .parse::<bool>()
289                .map_err(|_| ODataBuildError::Other("invalid bool in cursor"))?;
290            V::Bool(Some(b))
291        }
292        FieldKind::Uuid => {
293            let u = s
294                .parse::<uuid::Uuid>()
295                .map_err(|_| ODataBuildError::Other("invalid uuid in cursor"))?;
296            V::Uuid(Some(Box::new(u)))
297        }
298        FieldKind::DateTimeUtc => {
299            let dt = chrono::DateTime::parse_from_rfc3339(s)
300                .map_err(|_| ODataBuildError::Other("invalid datetime in cursor"))?
301                .with_timezone(&Utc);
302            V::ChronoDateTimeUtc(Some(Box::new(dt)))
303        }
304        FieldKind::Date => {
305            let d = s
306                .parse::<NaiveDate>()
307                .map_err(|_| ODataBuildError::Other("invalid date in cursor"))?;
308            V::ChronoDate(Some(Box::new(d)))
309        }
310        FieldKind::Time => {
311            let t = s
312                .parse::<NaiveTime>()
313                .map_err(|_| ODataBuildError::Other("invalid time in cursor"))?;
314            V::ChronoTime(Some(Box::new(t)))
315        }
316        FieldKind::Decimal => {
317            let d = s
318                .parse::<Decimal>()
319                .map_err(|_| ODataBuildError::Other("invalid decimal in cursor"))?;
320            V::Decimal(Some(Box::new(d)))
321        }
322    };
323
324    Ok(result)
325}
326
327/* ---------- cursor predicate building ---------- */
328
329/// Build a cursor predicate for pagination.
330/// This builds the lexicographic OR-chain condition for cursor-based pagination.
331///
332/// For backward pagination (cursor.d == "bwd"), the comparison operators are reversed
333/// to fetch items before the cursor, but the order remains the same for display consistency.
334///
335/// # Errors
336/// Returns `ODataBuildError` if cursor keys don't match order fields or field resolution fails.
337pub fn build_cursor_predicate<E: EntityTrait>(
338    cursor: &CursorV1,
339    order: &ODataOrderBy,
340    fmap: &FieldMap<E>,
341) -> ODataBuildResult<Condition>
342where
343    E::Column: ColumnTrait + Copy,
344{
345    if cursor.k.len() != order.0.len() {
346        return Err(ODataBuildError::Other(
347            "cursor keys count mismatch with order fields",
348        ));
349    }
350
351    // Parse cursor values
352    let mut cursor_values = Vec::new();
353    for (i, key_str) in cursor.k.iter().enumerate() {
354        let order_key = &order.0[i];
355        let field = fmap
356            .get(&order_key.field)
357            .ok_or_else(|| ODataBuildError::UnknownField(order_key.field.clone()))?;
358        let value = parse_cursor_value(field.kind, key_str)?;
359        cursor_values.push((field, value, order_key.dir));
360    }
361
362    // Determine if we're going backward
363    let is_backward = cursor.d == "bwd";
364
365    // Build lexicographic condition
366    // Forward (fwd):
367    //   For ASC: (k0 > v0) OR (k0 = v0 AND k1 > v1) OR ...
368    //   For DESC: (k0 < v0) OR (k0 = v0 AND k1 < v1) OR ...
369    // Backward (bwd): Reverse the comparisons
370    //   For ASC: (k0 < v0) OR (k0 = v0 AND k1 < v1) OR ...
371    //   For DESC: (k0 > v0) OR (k0 = v0 AND k1 > v1) OR ...
372    let mut main_condition = Condition::any();
373
374    for i in 0..cursor_values.len() {
375        let mut prefix_condition = Condition::all();
376
377        // Add equality conditions for all previous fields
378        for (field, value, _) in cursor_values.iter().take(i) {
379            prefix_condition = prefix_condition.add(Expr::col(field.col).eq(value.clone()));
380        }
381
382        // Add the comparison condition for current field
383        let (field, value, dir) = &cursor_values[i];
384        let comparison = if is_backward {
385            // Backward: reverse the comparison
386            match dir {
387                SortDir::Asc => Expr::col(field.col).lt(value.clone()),
388                SortDir::Desc => Expr::col(field.col).gt(value.clone()),
389            }
390        } else {
391            // Forward: normal comparison
392            match dir {
393                SortDir::Asc => Expr::col(field.col).gt(value.clone()),
394                SortDir::Desc => Expr::col(field.col).lt(value.clone()),
395            }
396        };
397        prefix_condition = prefix_condition.add(comparison);
398
399        main_condition = main_condition.add(prefix_condition);
400    }
401
402    Ok(main_condition)
403}
404
405/* ---------- error mapping helpers ---------- */
406
407/// Resolve a field by name, converting `UnknownField` errors to `InvalidOrderByField`
408fn resolve_field<'a, E: EntityTrait>(
409    fld_map: &'a FieldMap<E>,
410    name: &str,
411) -> Result<&'a Field<E>, ODataError> {
412    fld_map
413        .get(name)
414        .ok_or_else(|| ODataError::InvalidOrderByField(name.to_owned()))
415}
416
417/* ---------- tiebreaker handling ---------- */
418
419/// Ensure a tiebreaker field is present in the order
420pub fn ensure_tiebreaker(order: ODataOrderBy, tiebreaker: &str, dir: SortDir) -> ODataOrderBy {
421    order.ensure_tiebreaker(tiebreaker, dir)
422}
423
424/* ---------- cursor building ---------- */
425
426/// Build a cursor from a model using the effective order and field map extractors.
427///
428/// # Errors
429/// Returns `ODataError::InvalidOrderByField` if a field cannot be encoded.
430pub fn build_cursor_for_model<E: EntityTrait>(
431    model: &E::Model,
432    order: &ODataOrderBy,
433    fmap: &FieldMap<E>,
434    primary_dir: SortDir,
435    filter_hash: Option<String>,
436    direction: &str, // "fwd" or "bwd"
437) -> Result<CursorV1, ODataError> {
438    let mut k = Vec::with_capacity(order.0.len());
439    for key in &order.0 {
440        let s = fmap
441            .encode_model_key(model, &key.field)
442            .ok_or_else(|| ODataError::InvalidOrderByField(key.field.clone()))?;
443        k.push(s);
444    }
445    Ok(CursorV1 {
446        k,
447        o: primary_dir,
448        s: order.to_signed_tokens(),
449        f: filter_hash,
450        d: direction.to_owned(),
451    })
452}
453
454/* ---------- Expr (AST) -> Condition ---------- */
455
456/// Convert an `OData` filter expression AST to a `SeaORM` Condition.
457///
458/// # Errors
459/// Returns `ODataBuildError` if the expression contains unknown fields or unsupported operations.
460pub fn expr_to_condition<E: EntityTrait>(
461    expr: &core::Expr,
462    fmap: &FieldMap<E>,
463) -> ODataBuildResult<Condition>
464where
465    E::Column: ColumnTrait + Copy,
466{
467    use core::CompareOperator as Op;
468    use core::Expr as X;
469
470    Ok(match expr {
471        X::And(a, b) => {
472            let left = expr_to_condition::<E>(a, fmap)?;
473            let right = expr_to_condition::<E>(b, fmap)?;
474            Condition::all().add(left).add(right) // AND
475        }
476        X::Or(a, b) => {
477            let left = expr_to_condition::<E>(a, fmap)?;
478            let right = expr_to_condition::<E>(b, fmap)?;
479            Condition::any().add(left).add(right) // OR
480        }
481        X::Not(x) => {
482            let inner = expr_to_condition::<E>(x, fmap)?;
483            Condition::all().add(inner).not()
484        }
485
486        // Identifier op Value
487        X::Compare(lhs, op, rhs) => {
488            let (name, rhs_val) = match (&**lhs, &**rhs) {
489                (X::Identifier(name), X::Value(val)) => (name, val),
490                (X::Identifier(_), X::Identifier(_)) => {
491                    return Err(ODataBuildError::Other(
492                        "field-to-field comparison is not supported",
493                    ));
494                }
495                _ => return Err(ODataBuildError::Other("unsupported comparison form")),
496            };
497            let field = fmap
498                .get(name)
499                .ok_or_else(|| ODataBuildError::UnknownField(name.clone()))?;
500            let col = field.col;
501
502            // null handling
503            if matches!(rhs_val, core::Value::Null) {
504                return Ok(match op {
505                    Op::Eq => Condition::all().add(Expr::col(col).is_null()),
506                    Op::Ne => Condition::all().add(Expr::col(col).is_not_null()),
507                    _ => return Err(ODataBuildError::UnsupportedOp(*op)),
508                });
509            }
510
511            let value = coerce(field.kind, rhs_val)?;
512            let expr = match op {
513                Op::Eq => Expr::col(col).eq(value),
514                Op::Ne => Expr::col(col).ne(value),
515                Op::Gt => Expr::col(col).gt(value),
516                Op::Ge => Expr::col(col).gte(value),
517                Op::Lt => Expr::col(col).lt(value),
518                Op::Le => Expr::col(col).lte(value),
519            };
520            Condition::all().add(expr)
521        }
522
523        // Identifier IN (value, value, ...)
524        X::In(l, list) => {
525            let X::Identifier(name) = &**l else {
526                return Err(ODataBuildError::Other("left side of IN must be a field"));
527            };
528            let f = fmap
529                .get(name)
530                .ok_or_else(|| ODataBuildError::UnknownField(name.clone()))?;
531            let col = f.col;
532            let vals = coerce_many(f.kind, list)?;
533            if vals.is_empty() {
534                // IN () → always false
535                Condition::all().add(Expr::cust("1=0"))
536            } else {
537                Condition::all().add(Expr::col(col).is_in(vals))
538            }
539        }
540
541        // Supported functions: contains/startswith/endswith
542        X::Function(fname, args) => {
543            let n = fname.to_ascii_lowercase();
544            match (n.as_str(), args.as_slice()) {
545                ("contains", [X::Identifier(name), X::Value(core::Value::String(s))]) => {
546                    let f = fmap
547                        .get(name)
548                        .ok_or_else(|| ODataBuildError::UnknownField(name.clone()))?;
549                    ensure_string_field(f, name)?;
550                    Condition::all().add(Expr::col(f.col).like(like_contains(s)))
551                }
552                ("startswith", [X::Identifier(name), X::Value(core::Value::String(s))]) => {
553                    let f = fmap
554                        .get(name)
555                        .ok_or_else(|| ODataBuildError::UnknownField(name.clone()))?;
556                    ensure_string_field(f, name)?;
557                    Condition::all().add(Expr::col(f.col).like(like_starts(s)))
558                }
559                ("endswith", [X::Identifier(name), X::Value(core::Value::String(s))]) => {
560                    let f = fmap
561                        .get(name)
562                        .ok_or_else(|| ODataBuildError::UnknownField(name.clone()))?;
563                    ensure_string_field(f, name)?;
564                    Condition::all().add(Expr::col(f.col).like(like_ends(s)))
565                }
566                _ => return Err(ODataBuildError::UnsupportedFn(fname.clone())),
567            }
568        }
569
570        // Leaf forms are not valid WHERE by themselves
571        X::Identifier(name) => return Err(ODataBuildError::BareIdentifier(name.clone())),
572        X::Value(_) => return Err(ODataBuildError::BareLiteral),
573    })
574}
575
576/// Apply an optional `OData` filter (via wrapper) to a plain `SeaORM` Select<E>.
577///
578/// This extension does NOT parse the filter string — it only consumes a parsed AST
579/// (`modkit_odata::ast::Expr`) and translates it into a `sea_orm::Condition`.
580pub trait ODataExt<E: EntityTrait>: Sized {
581    /// Apply `OData` filter to the query.
582    ///
583    /// # Errors
584    /// Returns `ODataBuildError` if the filter contains unknown fields or invalid expressions.
585    fn apply_odata_filter(
586        self,
587        od_query: ODataQuery,
588        fld_map: &FieldMap<E>,
589    ) -> ODataBuildResult<Self>;
590}
591
592impl<E> ODataExt<E> for sea_orm::Select<E>
593where
594    E: EntityTrait,
595    E::Column: ColumnTrait + Copy,
596{
597    fn apply_odata_filter(
598        self,
599        od_query: ODataQuery,
600        fld_map: &FieldMap<E>,
601    ) -> ODataBuildResult<Self> {
602        match od_query.filter() {
603            Some(ast) => {
604                let cond = expr_to_condition::<E>(ast, fld_map)?;
605                Ok(self.filter(cond))
606            }
607            None => Ok(self),
608        }
609    }
610}
611
612/// Extension trait for applying cursor-based pagination
613pub trait CursorApplyExt<E: EntityTrait>: Sized {
614    /// Apply cursor-based forward pagination.
615    ///
616    /// # Errors
617    /// Returns `ODataBuildError` if cursor validation fails.
618    fn apply_cursor_forward(
619        self,
620        cursor: &CursorV1,
621        order: &ODataOrderBy,
622        fld_map: &FieldMap<E>,
623    ) -> ODataBuildResult<Self>;
624}
625
626impl<E> CursorApplyExt<E> for sea_orm::Select<E>
627where
628    E: EntityTrait,
629    E::Column: ColumnTrait + Copy,
630{
631    fn apply_cursor_forward(
632        self,
633        cursor: &CursorV1,
634        order: &ODataOrderBy,
635        fld_map: &FieldMap<E>,
636    ) -> ODataBuildResult<Self> {
637        let cond = build_cursor_predicate(cursor, order, fld_map)?;
638        Ok(self.filter(cond))
639    }
640}
641
642/// Extension trait for applying ordering (legacy version with `ODataBuildError`)
643pub trait ODataOrderExt<E: EntityTrait>: Sized {
644    /// Apply `OData` ordering to the query.
645    ///
646    /// # Errors
647    /// Returns `ODataBuildError` if an unknown field is referenced.
648    fn apply_odata_order(
649        self,
650        order: &ODataOrderBy,
651        fld_map: &FieldMap<E>,
652    ) -> ODataBuildResult<Self>;
653}
654
655impl<E> ODataOrderExt<E> for sea_orm::Select<E>
656where
657    E: EntityTrait,
658    E::Column: ColumnTrait + Copy,
659{
660    fn apply_odata_order(
661        self,
662        order: &ODataOrderBy,
663        fld_map: &FieldMap<E>,
664    ) -> ODataBuildResult<Self> {
665        let mut query = self;
666
667        for order_key in &order.0 {
668            let field = fld_map
669                .get(&order_key.field)
670                .ok_or_else(|| ODataBuildError::UnknownField(order_key.field.clone()))?;
671
672            let sea_order = match order_key.dir {
673                SortDir::Asc => Order::Asc,
674                SortDir::Desc => Order::Desc,
675            };
676
677            query = query.order_by(field.col, sea_order);
678        }
679
680        Ok(query)
681    }
682}
683
684/// Extension trait for applying ordering with centralized error handling
685pub trait ODataOrderPageExt<E: EntityTrait>: Sized {
686    /// Apply `OData` ordering with page-level error handling.
687    ///
688    /// # Errors
689    /// Returns `ODataError` if an unknown field is referenced.
690    fn apply_odata_order_page(
691        self,
692        order: &ODataOrderBy,
693        fld_map: &FieldMap<E>,
694    ) -> Result<Self, ODataError>;
695}
696
697impl<E> ODataOrderPageExt<E> for sea_orm::Select<E>
698where
699    E: EntityTrait,
700    E::Column: ColumnTrait + Copy,
701{
702    fn apply_odata_order_page(
703        self,
704        order: &ODataOrderBy,
705        fld_map: &FieldMap<E>,
706    ) -> Result<Self, ODataError> {
707        let mut query = self;
708
709        for order_key in &order.0 {
710            let field = resolve_field(fld_map, &order_key.field)?;
711
712            let sea_order = match order_key.dir {
713                SortDir::Asc => Order::Asc,
714                SortDir::Desc => Order::Desc,
715            };
716
717            query = query.order_by(field.col, sea_order);
718        }
719
720        Ok(query)
721    }
722}
723
724/// Extension trait for applying full `OData` query (filter + cursor + order)
725pub trait ODataQueryExt<E: EntityTrait>: Sized {
726    /// Apply full `OData` query including filter, cursor, and ordering.
727    ///
728    /// # Errors
729    /// Returns `ODataBuildError` if any part of the query application fails.
730    fn apply_odata_query(
731        self,
732        query: &ODataQuery,
733        fld_map: &FieldMap<E>,
734        tiebreaker: (&str, SortDir),
735    ) -> ODataBuildResult<Self>;
736}
737
738impl<E> ODataQueryExt<E> for sea_orm::Select<E>
739where
740    E: EntityTrait,
741    E::Column: ColumnTrait + Copy,
742{
743    fn apply_odata_query(
744        self,
745        query: &ODataQuery,
746        fld_map: &FieldMap<E>,
747        tiebreaker: (&str, SortDir),
748    ) -> ODataBuildResult<Self> {
749        let mut select = self;
750
751        if let Some(ast) = query.filter.as_deref() {
752            let cond = expr_to_condition::<E>(ast, fld_map)?;
753            select = select.filter(cond);
754        }
755
756        let effective_order = ensure_tiebreaker(query.order.clone(), tiebreaker.0, tiebreaker.1);
757
758        if let Some(cursor) = &query.cursor {
759            select = select.apply_cursor_forward(cursor, &effective_order, fld_map)?;
760        }
761
762        select = select.apply_odata_order(&effective_order, fld_map)?;
763
764        Ok(select)
765    }
766}
767
768/* ---------- pagination combiner ---------- */
769
770// Use unified pagination types from modkit-odata
771pub use modkit_odata::{Page, PageInfo};
772
773// Note: LimitCfg is imported at the top and re-exported from odata/mod.rs
774
775fn clamp_limit(req: Option<u64>, cfg: LimitCfg) -> u64 {
776    let mut l = req.unwrap_or(cfg.default);
777    if l == 0 {
778        l = 1;
779    }
780    if l > cfg.max {
781        l = cfg.max;
782    }
783    l
784}
785
786/// One-shot pagination combiner that handles filter → cursor predicate → order → overfetch/trim → build cursors.
787///
788/// # Errors
789/// Returns `ODataError` if filter application, cursor validation, or database query fails.
790pub async fn paginate_with_odata<E, D, F, C>(
791    select: sea_orm::Select<E>,
792    conn: &C,
793    q: &ODataQuery,
794    fmap: &FieldMap<E>,
795    tiebreaker: (&str, SortDir), // e.g. ("id", SortDir::Desc)
796    limit_cfg: LimitCfg,         // e.g. { default: 25, max: 1000 }
797    model_to_domain: F,
798) -> Result<Page<D>, ODataError>
799where
800    E: EntityTrait,
801    E::Column: ColumnTrait + Copy,
802    F: Fn(E::Model) -> D + Copy,
803    C: ConnectionTrait + Send + Sync,
804{
805    let limit = clamp_limit(q.limit, limit_cfg);
806    let fetch = limit + 1;
807
808    // Effective order derivation based on new policy
809    let effective_order = if let Some(cur) = &q.cursor {
810        // Derive order from the cursor's signed tokens
811        modkit_odata::ODataOrderBy::from_signed_tokens(&cur.s)
812            .map_err(|_| ODataError::InvalidCursor)?
813    } else {
814        // Use client order; ensure tiebreaker
815        q.order
816            .clone()
817            .ensure_tiebreaker(tiebreaker.0, tiebreaker.1)
818    };
819
820    // Validate cursor consistency (filter hash only) if cursor present
821    if let Some(cur) = &q.cursor
822        && let (Some(h), Some(cf)) = (q.filter_hash.as_deref(), cur.f.as_deref())
823        && h != cf
824    {
825        return Err(ODataError::FilterMismatch);
826    }
827
828    // Compose: filter → cursor predicate → order; apply limit+1 at the end
829    let mut s = select;
830
831    // Apply filter
832    if let Some(ast) = q.filter.as_deref() {
833        s = s.filter(
834            expr_to_condition::<E>(ast, fmap)
835                .map_err(|e| ODataError::InvalidFilter(e.to_string()))?,
836        );
837    }
838
839    // Check if we're paginating backward
840    let is_backward = q.cursor.as_ref().is_some_and(|c| c.d == "bwd");
841
842    // Apply cursor if present
843    if let Some(cursor) = &q.cursor {
844        s = s.filter(
845            build_cursor_predicate(cursor, &effective_order, fmap)
846                .map_err(|_| ODataError::InvalidCursor)?,
847        );
848    }
849
850    // Apply order (reverse it for backward pagination)
851    let query_order = if is_backward {
852        effective_order.clone().reverse_directions()
853    } else {
854        effective_order.clone()
855    };
856    s = s.apply_odata_order_page(&query_order, fmap)?;
857
858    // Apply limit
859    s = s.limit(fetch);
860
861    #[allow(clippy::disallowed_methods)]
862    let mut rows = s
863        .all(conn)
864        .await
865        .map_err(|e| ODataError::Db(e.to_string()))?;
866
867    let has_more = (rows.len() as u64) > limit;
868
869    // For backward pagination with reversed ORDER BY:
870    // - DB returns items in opposite order
871    // - We fetch limit+1 to detect has_more
872    // - We need to: 1) trim, 2) reverse back to original order
873    if is_backward {
874        // Remove the extra item (furthest back in time, which is at the END after reversed query)
875        if has_more {
876            rows.pop();
877        }
878        // Reverse to restore original display order
879        rows.reverse();
880    } else if has_more {
881        // Forward pagination: just truncate the end
882        rows.truncate(usize::try_from(limit).unwrap_or(usize::MAX));
883    }
884
885    // Build cursors
886    // After all the reversals, rows are in the display order (DESC)
887    // - rows.first() = newest item
888    // - rows.last() = oldest item
889    //
890    // For backward pagination:
891    //   - has_more means "more items backward" (older)
892    //   - next_cursor should always be present (we came from forward)
893    //   - prev_cursor based on has_more
894    // For forward pagination:
895    //   - has_more means "more items forward" (older in DESC)
896    //   - next_cursor based on has_more
897    //   - prev_cursor always present (unless at start)
898
899    let next_cursor = if is_backward {
900        // Going backward: always have items forward (unless this was the initial query)
901        // Build cursor from last item to go forward
902        build_cursor(&rows, &effective_order, fmap, tiebreaker, q, true, "fwd")?
903    } else if has_more {
904        // Going forward: only have more if has_more is true
905        build_cursor(&rows, &effective_order, fmap, tiebreaker, q, true, "fwd")?
906    } else {
907        None
908    };
909
910    let prev_cursor = if is_backward {
911        // Going backward: only have more backward if has_more is true
912        if has_more {
913            build_cursor(&rows, &effective_order, fmap, tiebreaker, q, false, "bwd")?
914        } else {
915            None
916        }
917    } else if q.cursor.is_some() {
918        // Going forward: have items backward only if this is NOT the initial query
919        // If q.cursor is None, we're at the start of the dataset
920        build_cursor(&rows, &effective_order, fmap, tiebreaker, q, false, "bwd")?
921    } else {
922        None
923    };
924
925    let items = rows.into_iter().map(model_to_domain).collect();
926
927    Ok(Page {
928        items,
929        page_info: PageInfo {
930            next_cursor,
931            prev_cursor,
932            limit,
933        },
934    })
935}
936
937fn build_cursor<E: EntityTrait>(
938    rows: &[E::Model],
939    effective_order: &ODataOrderBy,
940    fmap: &FieldMap<E>,
941    tiebreaker: (&str, SortDir),
942    q: &ODataQuery,
943    last: bool,
944    direction: &str,
945) -> Result<Option<String>, ODataError> {
946    if last { rows.last() } else { rows.first() }
947        .map(|m| {
948            build_cursor_for_model::<E>(
949                m,
950                effective_order,
951                fmap,
952                tiebreaker.1,
953                q.filter_hash.clone(),
954                direction,
955            )
956            .and_then(|c| c.encode().map_err(|_| ODataError::InvalidCursor))
957        })
958        .transpose()
959}