use std::collections::HashMap;
use bigdecimal::{BigDecimal, ToPrimitive};
use chrono::{NaiveDate, NaiveTime, Utc};
use modkit_odata::{CursorV1, Error as ODataError, ODataOrderBy, ODataQuery, SortDir, ast as core};
use rust_decimal::Decimal;
use sea_orm::{
ColumnTrait, Condition, EntityTrait, QueryFilter, QueryOrder, QuerySelect,
sea_query::{Expr, Order},
};
use thiserror::Error;
use modkit_odata::filter::FieldKind;
use crate::odata::LimitCfg;
use crate::secure::{DBRunner, DBRunnerInternal, SeaOrmRunner};
type CursorExtractor<E> = fn(&<E as EntityTrait>::Model) -> String;
#[derive(Clone)]
pub struct Field<E: EntityTrait> {
pub col: E::Column,
pub kind: FieldKind,
pub to_string_for_cursor: Option<CursorExtractor<E>>,
}
#[derive(Clone)]
#[must_use]
pub struct FieldMap<E: EntityTrait> {
map: HashMap<String, Field<E>>,
}
impl<E: EntityTrait> Default for FieldMap<E> {
fn default() -> Self {
Self::new()
}
}
impl<E: EntityTrait> FieldMap<E> {
pub fn new() -> Self {
Self {
map: HashMap::new(),
}
}
pub fn insert(mut self, api_name: impl Into<String>, col: E::Column, kind: FieldKind) -> Self {
self.map.insert(
api_name.into().to_lowercase(),
Field {
col,
kind,
to_string_for_cursor: None,
},
);
self
}
pub fn insert_with_extractor(
mut self,
api_name: impl Into<String>,
col: E::Column,
kind: FieldKind,
to_string_for_cursor: CursorExtractor<E>,
) -> Self {
self.map.insert(
api_name.into().to_lowercase(),
Field {
col,
kind,
to_string_for_cursor: Some(to_string_for_cursor),
},
);
self
}
pub fn encode_model_key(&self, model: &E::Model, field_name: &str) -> Option<String> {
let f = self.get(field_name)?;
f.to_string_for_cursor.map(|f| f(model))
}
#[must_use]
pub fn get(&self, name: &str) -> Option<&Field<E>> {
self.map.get(&name.to_lowercase())
}
}
#[derive(Debug, Error, Clone)]
pub enum ODataBuildError {
#[error("unknown field: {0}")]
UnknownField(String),
#[error("type mismatch: expected {expected:?}, got {got}")]
TypeMismatch {
expected: FieldKind,
got: &'static str,
},
#[error("unsupported operator: {0:?}")]
UnsupportedOp(core::CompareOperator),
#[error("unsupported function or args: {0}()")]
UnsupportedFn(String),
#[error("IN() list supports only literals")]
NonLiteralInList,
#[error("bare identifier not allowed: {0}")]
BareIdentifier(String),
#[error("bare literal not allowed")]
BareLiteral,
#[error("{0}")]
Other(&'static str),
}
pub type ODataBuildResult<T> = Result<T, ODataBuildError>;
fn bigdecimal_to_decimal(bd: &BigDecimal) -> ODataBuildResult<Decimal> {
let s = bd.normalized().to_string();
Decimal::from_str_exact(&s)
.or_else(|_| s.parse::<Decimal>())
.map_err(|_| ODataBuildError::Other("invalid decimal"))
}
fn coerce(kind: FieldKind, v: &core::Value) -> ODataBuildResult<sea_orm::Value> {
use core::Value as V;
Ok(match (kind, v) {
(FieldKind::String, V::String(s)) => sea_orm::Value::String(Some(Box::new(s.clone()))),
(FieldKind::I64, V::Number(n)) => {
let i = n.to_i64().ok_or(ODataBuildError::TypeMismatch {
expected: FieldKind::I64,
got: "number",
})?;
sea_orm::Value::BigInt(Some(i))
}
(FieldKind::F64, V::Number(n)) => {
let f = n.to_f64().ok_or(ODataBuildError::TypeMismatch {
expected: FieldKind::F64,
got: "number",
})?;
sea_orm::Value::Double(Some(f))
}
(FieldKind::Decimal, V::Number(n)) => {
sea_orm::Value::Decimal(Some(Box::new(bigdecimal_to_decimal(n)?)))
}
(FieldKind::Bool, V::Bool(b)) => sea_orm::Value::Bool(Some(*b)),
(FieldKind::Uuid, V::Uuid(u)) => sea_orm::Value::Uuid(Some(Box::new(*u))),
(FieldKind::DateTimeUtc, V::DateTime(dt)) => {
sea_orm::Value::ChronoDateTimeUtc(Some(Box::new(*dt)))
}
(FieldKind::Date, V::Date(d)) => sea_orm::Value::ChronoDate(Some(Box::new(*d))),
(FieldKind::Time, V::Time(t)) => sea_orm::Value::ChronoTime(Some(Box::new(*t))),
(expected, V::Null) => {
return Err(ODataBuildError::TypeMismatch {
expected,
got: "null",
});
}
(expected, V::String(_)) => {
return Err(ODataBuildError::TypeMismatch {
expected,
got: "string",
});
}
(expected, V::Number(_)) => {
return Err(ODataBuildError::TypeMismatch {
expected,
got: "number",
});
}
(expected, V::Bool(_)) => {
return Err(ODataBuildError::TypeMismatch {
expected,
got: "bool",
});
}
(expected, V::Uuid(_)) => {
return Err(ODataBuildError::TypeMismatch {
expected,
got: "uuid",
});
}
(expected, V::DateTime(_)) => {
return Err(ODataBuildError::TypeMismatch {
expected,
got: "datetime",
});
}
(expected, V::Date(_)) => {
return Err(ODataBuildError::TypeMismatch {
expected,
got: "date",
});
}
(expected, V::Time(_)) => {
return Err(ODataBuildError::TypeMismatch {
expected,
got: "time",
});
}
})
}
fn coerce_many(kind: FieldKind, items: &[core::Expr]) -> ODataBuildResult<Vec<sea_orm::Value>> {
items
.iter()
.map(|e| match e {
core::Expr::Value(v) => coerce(kind, v),
_ => Err(ODataBuildError::NonLiteralInList),
})
.collect()
}
fn like_escape(s: &str) -> String {
let mut out = String::with_capacity(s.len());
for ch in s.chars() {
match ch {
'%' | '_' | '\\' => {
out.push('\\');
out.push(ch);
}
c => out.push(c),
}
}
out
}
fn like_contains(s: &str) -> String {
format!("%{}%", like_escape(s))
}
fn like_starts(s: &str) -> String {
format!("{}%", like_escape(s))
}
fn like_ends(s: &str) -> String {
format!("%{}", like_escape(s))
}
#[inline]
fn ensure_string_field<E: EntityTrait>(f: &Field<E>, _field_name: &str) -> ODataBuildResult<()> {
if f.kind != FieldKind::String {
return Err(ODataBuildError::TypeMismatch {
expected: FieldKind::String,
got: "non-string field",
});
}
Ok(())
}
pub fn parse_cursor_value(kind: FieldKind, s: &str) -> ODataBuildResult<sea_orm::Value> {
use sea_orm::Value as V;
let result = match kind {
FieldKind::String => V::String(Some(Box::new(s.to_owned()))),
FieldKind::I64 => {
let i = s
.parse::<i64>()
.map_err(|_| ODataBuildError::Other("invalid i64 in cursor"))?;
V::BigInt(Some(i))
}
FieldKind::F64 => {
let f = s
.parse::<f64>()
.map_err(|_| ODataBuildError::Other("invalid f64 in cursor"))?;
V::Double(Some(f))
}
FieldKind::Bool => {
let b = s
.parse::<bool>()
.map_err(|_| ODataBuildError::Other("invalid bool in cursor"))?;
V::Bool(Some(b))
}
FieldKind::Uuid => {
let u = s
.parse::<uuid::Uuid>()
.map_err(|_| ODataBuildError::Other("invalid uuid in cursor"))?;
V::Uuid(Some(Box::new(u)))
}
FieldKind::DateTimeUtc => {
let dt = chrono::DateTime::parse_from_rfc3339(s)
.map_err(|_| ODataBuildError::Other("invalid datetime in cursor"))?
.with_timezone(&Utc);
V::ChronoDateTimeUtc(Some(Box::new(dt)))
}
FieldKind::Date => {
let d = s
.parse::<NaiveDate>()
.map_err(|_| ODataBuildError::Other("invalid date in cursor"))?;
V::ChronoDate(Some(Box::new(d)))
}
FieldKind::Time => {
let t = s
.parse::<NaiveTime>()
.map_err(|_| ODataBuildError::Other("invalid time in cursor"))?;
V::ChronoTime(Some(Box::new(t)))
}
FieldKind::Decimal => {
let d = s
.parse::<Decimal>()
.map_err(|_| ODataBuildError::Other("invalid decimal in cursor"))?;
V::Decimal(Some(Box::new(d)))
}
};
Ok(result)
}
pub fn build_cursor_predicate<E: EntityTrait>(
cursor: &CursorV1,
order: &ODataOrderBy,
fmap: &FieldMap<E>,
) -> ODataBuildResult<Condition>
where
E::Column: ColumnTrait + Copy,
{
if cursor.k.len() != order.0.len() {
return Err(ODataBuildError::Other(
"cursor keys count mismatch with order fields",
));
}
let mut cursor_values = Vec::new();
for (i, key_str) in cursor.k.iter().enumerate() {
let order_key = &order.0[i];
let field = fmap
.get(&order_key.field)
.ok_or_else(|| ODataBuildError::UnknownField(order_key.field.clone()))?;
let value = parse_cursor_value(field.kind, key_str)?;
cursor_values.push((field, value, order_key.dir));
}
let is_backward = cursor.d == "bwd";
let mut main_condition = Condition::any();
for i in 0..cursor_values.len() {
let mut prefix_condition = Condition::all();
for (field, value, _) in cursor_values.iter().take(i) {
prefix_condition = prefix_condition.add(Expr::col(field.col).eq(value.clone()));
}
let (field, value, dir) = &cursor_values[i];
let comparison = if is_backward {
match dir {
SortDir::Asc => Expr::col(field.col).lt(value.clone()),
SortDir::Desc => Expr::col(field.col).gt(value.clone()),
}
} else {
match dir {
SortDir::Asc => Expr::col(field.col).gt(value.clone()),
SortDir::Desc => Expr::col(field.col).lt(value.clone()),
}
};
prefix_condition = prefix_condition.add(comparison);
main_condition = main_condition.add(prefix_condition);
}
Ok(main_condition)
}
fn resolve_field<'a, E: EntityTrait>(
fld_map: &'a FieldMap<E>,
name: &str,
) -> Result<&'a Field<E>, ODataError> {
fld_map
.get(name)
.ok_or_else(|| ODataError::InvalidOrderByField(name.to_owned()))
}
pub fn ensure_tiebreaker(order: ODataOrderBy, tiebreaker: &str, dir: SortDir) -> ODataOrderBy {
order.ensure_tiebreaker(tiebreaker, dir)
}
pub fn build_cursor_for_model<E: EntityTrait>(
model: &E::Model,
order: &ODataOrderBy,
fmap: &FieldMap<E>,
primary_dir: SortDir,
filter_hash: Option<String>,
direction: &str, ) -> Result<CursorV1, ODataError> {
let mut k = Vec::with_capacity(order.0.len());
for key in &order.0 {
let s = fmap
.encode_model_key(model, &key.field)
.ok_or_else(|| ODataError::InvalidOrderByField(key.field.clone()))?;
k.push(s);
}
Ok(CursorV1 {
k,
o: primary_dir,
s: order.to_signed_tokens(),
f: filter_hash,
d: direction.to_owned(),
})
}
pub fn expr_to_condition<E: EntityTrait>(
expr: &core::Expr,
fmap: &FieldMap<E>,
) -> ODataBuildResult<Condition>
where
E::Column: ColumnTrait + Copy,
{
use core::CompareOperator as Op;
use core::Expr as X;
Ok(match expr {
X::And(a, b) => {
let left = expr_to_condition::<E>(a, fmap)?;
let right = expr_to_condition::<E>(b, fmap)?;
Condition::all().add(left).add(right) }
X::Or(a, b) => {
let left = expr_to_condition::<E>(a, fmap)?;
let right = expr_to_condition::<E>(b, fmap)?;
Condition::any().add(left).add(right) }
X::Not(x) => {
let inner = expr_to_condition::<E>(x, fmap)?;
Condition::all().add(inner).not()
}
X::Compare(lhs, op, rhs) => {
let (name, rhs_val) = match (&**lhs, &**rhs) {
(X::Identifier(name), X::Value(val)) => (name, val),
(X::Identifier(_), X::Identifier(_)) => {
return Err(ODataBuildError::Other(
"field-to-field comparison is not supported",
));
}
_ => return Err(ODataBuildError::Other("unsupported comparison form")),
};
let field = fmap
.get(name)
.ok_or_else(|| ODataBuildError::UnknownField(name.clone()))?;
let col = field.col;
if matches!(rhs_val, core::Value::Null) {
return Ok(match op {
Op::Eq => Condition::all().add(Expr::col(col).is_null()),
Op::Ne => Condition::all().add(Expr::col(col).is_not_null()),
_ => return Err(ODataBuildError::UnsupportedOp(*op)),
});
}
let value = coerce(field.kind, rhs_val)?;
let expr = match op {
Op::Eq => Expr::col(col).eq(value),
Op::Ne => Expr::col(col).ne(value),
Op::Gt => Expr::col(col).gt(value),
Op::Ge => Expr::col(col).gte(value),
Op::Lt => Expr::col(col).lt(value),
Op::Le => Expr::col(col).lte(value),
};
Condition::all().add(expr)
}
X::In(l, list) => {
let X::Identifier(name) = &**l else {
return Err(ODataBuildError::Other("left side of IN must be a field"));
};
let f = fmap
.get(name)
.ok_or_else(|| ODataBuildError::UnknownField(name.clone()))?;
let col = f.col;
let vals = coerce_many(f.kind, list)?;
if vals.is_empty() {
Condition::all().add(Expr::value(1).eq(0))
} else {
Condition::all().add(Expr::col(col).is_in(vals))
}
}
X::Function(fname, args) => {
let n = fname.to_ascii_lowercase();
match (n.as_str(), args.as_slice()) {
("contains", [X::Identifier(name), X::Value(core::Value::String(s))]) => {
let f = fmap
.get(name)
.ok_or_else(|| ODataBuildError::UnknownField(name.clone()))?;
ensure_string_field(f, name)?;
Condition::all().add(Expr::col(f.col).like(like_contains(s)))
}
("startswith", [X::Identifier(name), X::Value(core::Value::String(s))]) => {
let f = fmap
.get(name)
.ok_or_else(|| ODataBuildError::UnknownField(name.clone()))?;
ensure_string_field(f, name)?;
Condition::all().add(Expr::col(f.col).like(like_starts(s)))
}
("endswith", [X::Identifier(name), X::Value(core::Value::String(s))]) => {
let f = fmap
.get(name)
.ok_or_else(|| ODataBuildError::UnknownField(name.clone()))?;
ensure_string_field(f, name)?;
Condition::all().add(Expr::col(f.col).like(like_ends(s)))
}
_ => return Err(ODataBuildError::UnsupportedFn(fname.clone())),
}
}
X::Identifier(name) => return Err(ODataBuildError::BareIdentifier(name.clone())),
X::Value(_) => return Err(ODataBuildError::BareLiteral),
})
}
pub trait ODataExt<E: EntityTrait>: Sized {
fn apply_odata_filter(
self,
od_query: ODataQuery,
fld_map: &FieldMap<E>,
) -> ODataBuildResult<Self>;
}
impl<E> ODataExt<E> for sea_orm::Select<E>
where
E: EntityTrait,
E::Column: ColumnTrait + Copy,
{
fn apply_odata_filter(
self,
od_query: ODataQuery,
fld_map: &FieldMap<E>,
) -> ODataBuildResult<Self> {
match od_query.filter() {
Some(ast) => {
let cond = expr_to_condition::<E>(ast, fld_map)?;
Ok(self.filter(cond))
}
None => Ok(self),
}
}
}
pub trait CursorApplyExt<E: EntityTrait>: Sized {
fn apply_cursor_forward(
self,
cursor: &CursorV1,
order: &ODataOrderBy,
fld_map: &FieldMap<E>,
) -> ODataBuildResult<Self>;
}
impl<E> CursorApplyExt<E> for sea_orm::Select<E>
where
E: EntityTrait,
E::Column: ColumnTrait + Copy,
{
fn apply_cursor_forward(
self,
cursor: &CursorV1,
order: &ODataOrderBy,
fld_map: &FieldMap<E>,
) -> ODataBuildResult<Self> {
let cond = build_cursor_predicate(cursor, order, fld_map)?;
Ok(self.filter(cond))
}
}
pub trait ODataOrderExt<E: EntityTrait>: Sized {
fn apply_odata_order(
self,
order: &ODataOrderBy,
fld_map: &FieldMap<E>,
) -> ODataBuildResult<Self>;
}
impl<E> ODataOrderExt<E> for sea_orm::Select<E>
where
E: EntityTrait,
E::Column: ColumnTrait + Copy,
{
fn apply_odata_order(
self,
order: &ODataOrderBy,
fld_map: &FieldMap<E>,
) -> ODataBuildResult<Self> {
let mut query = self;
for order_key in &order.0 {
let field = fld_map
.get(&order_key.field)
.ok_or_else(|| ODataBuildError::UnknownField(order_key.field.clone()))?;
let sea_order = match order_key.dir {
SortDir::Asc => Order::Asc,
SortDir::Desc => Order::Desc,
};
query = query.order_by(field.col, sea_order);
}
Ok(query)
}
}
pub trait ODataOrderPageExt<E: EntityTrait>: Sized {
fn apply_odata_order_page(
self,
order: &ODataOrderBy,
fld_map: &FieldMap<E>,
) -> Result<Self, ODataError>;
}
impl<E> ODataOrderPageExt<E> for sea_orm::Select<E>
where
E: EntityTrait,
E::Column: ColumnTrait + Copy,
{
fn apply_odata_order_page(
self,
order: &ODataOrderBy,
fld_map: &FieldMap<E>,
) -> Result<Self, ODataError> {
let mut query = self;
for order_key in &order.0 {
let field = resolve_field(fld_map, &order_key.field)?;
let sea_order = match order_key.dir {
SortDir::Asc => Order::Asc,
SortDir::Desc => Order::Desc,
};
query = query.order_by(field.col, sea_order);
}
Ok(query)
}
}
pub trait ODataQueryExt<E: EntityTrait>: Sized {
fn apply_odata_query(
self,
query: &ODataQuery,
fld_map: &FieldMap<E>,
tiebreaker: (&str, SortDir),
) -> ODataBuildResult<Self>;
}
impl<E> ODataQueryExt<E> for sea_orm::Select<E>
where
E: EntityTrait,
E::Column: ColumnTrait + Copy,
{
fn apply_odata_query(
self,
query: &ODataQuery,
fld_map: &FieldMap<E>,
tiebreaker: (&str, SortDir),
) -> ODataBuildResult<Self> {
let mut select = self;
if let Some(ast) = query.filter.as_deref() {
let cond = expr_to_condition::<E>(ast, fld_map)?;
select = select.filter(cond);
}
let effective_order = ensure_tiebreaker(query.order.clone(), tiebreaker.0, tiebreaker.1);
if let Some(cursor) = &query.cursor {
select = select.apply_cursor_forward(cursor, &effective_order, fld_map)?;
}
select = select.apply_odata_order(&effective_order, fld_map)?;
Ok(select)
}
}
pub use modkit_odata::{Page, PageInfo};
fn clamp_limit(req: Option<u64>, cfg: LimitCfg) -> u64 {
let mut l = req.unwrap_or(cfg.default);
if l == 0 {
l = 1;
}
if l > cfg.max {
l = cfg.max;
}
l
}
pub async fn paginate_with_odata<E, D, F, C>(
select: sea_orm::Select<E>,
conn: &C,
q: &ODataQuery,
fmap: &FieldMap<E>,
tiebreaker: (&str, SortDir), limit_cfg: LimitCfg, model_to_domain: F,
) -> Result<Page<D>, ODataError>
where
E: EntityTrait,
E::Column: ColumnTrait + Copy,
F: Fn(E::Model) -> D + Copy,
C: DBRunner,
{
let limit = clamp_limit(q.limit, limit_cfg);
let fetch = limit + 1;
let effective_order = if let Some(cur) = &q.cursor {
modkit_odata::ODataOrderBy::from_signed_tokens(&cur.s)
.map_err(|_| ODataError::InvalidCursor)?
} else {
q.order
.clone()
.ensure_tiebreaker(tiebreaker.0, tiebreaker.1)
};
if let Some(cur) = &q.cursor
&& let (Some(h), Some(cf)) = (q.filter_hash.as_deref(), cur.f.as_deref())
&& h != cf
{
return Err(ODataError::FilterMismatch);
}
let mut s = select;
if let Some(ast) = q.filter.as_deref() {
s = s.filter(
expr_to_condition::<E>(ast, fmap)
.map_err(|e| ODataError::InvalidFilter(e.to_string()))?,
);
}
let is_backward = q.cursor.as_ref().is_some_and(|c| c.d == "bwd");
if let Some(cursor) = &q.cursor {
s = s.filter(
build_cursor_predicate(cursor, &effective_order, fmap)
.map_err(|_| ODataError::InvalidCursor)?,
);
}
let query_order = if is_backward {
effective_order.clone().reverse_directions()
} else {
effective_order.clone()
};
s = s.apply_odata_order_page(&query_order, fmap)?;
s = s.limit(fetch);
#[allow(clippy::disallowed_methods)]
let mut rows = match DBRunnerInternal::as_seaorm(conn) {
SeaOrmRunner::Conn(db) => s.all(db).await,
SeaOrmRunner::Tx(tx) => s.all(tx).await,
}
.map_err(|e| ODataError::Db(e.to_string()))?;
let has_more = (rows.len() as u64) > limit;
if is_backward {
if has_more {
rows.pop();
}
rows.reverse();
} else if has_more {
rows.truncate(usize::try_from(limit).unwrap_or(usize::MAX));
}
let next_cursor = if is_backward {
build_cursor(&rows, &effective_order, fmap, tiebreaker, q, true, "fwd")?
} else if has_more {
build_cursor(&rows, &effective_order, fmap, tiebreaker, q, true, "fwd")?
} else {
None
};
let prev_cursor = if is_backward {
if has_more {
build_cursor(&rows, &effective_order, fmap, tiebreaker, q, false, "bwd")?
} else {
None
}
} else if q.cursor.is_some() {
build_cursor(&rows, &effective_order, fmap, tiebreaker, q, false, "bwd")?
} else {
None
};
let items = rows.into_iter().map(model_to_domain).collect();
Ok(Page {
items,
page_info: PageInfo {
next_cursor,
prev_cursor,
limit,
},
})
}
fn build_cursor<E: EntityTrait>(
rows: &[E::Model],
effective_order: &ODataOrderBy,
fmap: &FieldMap<E>,
tiebreaker: (&str, SortDir),
q: &ODataQuery,
last: bool,
direction: &str,
) -> Result<Option<String>, ODataError> {
if last { rows.last() } else { rows.first() }
.map(|m| {
build_cursor_for_model::<E>(
m,
effective_order,
fmap,
tiebreaker.1,
q.filter_hash.clone(),
direction,
)
.and_then(|c| c.encode().map_err(|_| ODataError::InvalidCursor))
})
.transpose()
}