use super::{List, Query};
use crate::{Executor, Result, engine::eval::Func, schema::Load, stmt::IntoStatement};
use toasty_core::stmt::{self, Expr, ExprRecord, OrderBy, Projection, Value, VisitMut, visit_mut};
/// A paginated view over a list query.
///
/// Wraps a list query whose `limit` acts as the page size and whose `order_by`
/// defines the ordering key used to delimit pages, together with the direction in
/// which the current page is fetched.
#[derive(Debug)]
pub struct Paginate<M> {
    /// The underlying list query, carrying the page-size limit and, once `after` or
    /// `before` has been called, the cursor stored as the limit's offset.
    query: Query<List<M>>,
    /// `true` when paging backwards from a cursor (set by `before`). `exec` then runs
    /// the query with its ordering reversed and flips the fetched rows back.
    reverse: bool,
}
impl<M> Paginate<M> {
    /// Creates a paginator that returns at most `per_page` rows per page.
    ///
    /// The query must be ordered, because pages are delimited by the `order_by` key.
    /// It must also not already carry a `limit`, because this constructor installs its own.
    ///
    /// # Panics
    ///
    /// Panics if `query` already has a limit clause or has no `order_by` clause.
    pub fn new(mut query: Query<List<M>>, per_page: usize) -> Self {
        assert!(
            query.untyped.limit.is_none(),
            "pagination requires no limit clause"
        );
        assert!(
            query.untyped.order_by.is_some(),
            "pagination requires an order_by clause"
        );

        // `usize` can exceed `i64::MAX` on 64-bit targets. Clamp instead of letting an
        // `as` cast wrap around into a negative limit.
        let per_page = i64::try_from(per_page).unwrap_or(i64::MAX);

        query.untyped.limit = Some(stmt::Limit {
            limit: stmt::Value::from(per_page).into(),
            offset: None,
        });

        Self {
            query,
            reverse: false,
        }
    }

    /// Positions the page immediately after the row whose ordering key is `key`.
    ///
    /// # Panics
    ///
    /// Panics if the query has no limit clause.
    pub fn after(self, key: impl Into<stmt::Expr>) -> Self {
        self.with_cursor(key.into(), false)
    }

    /// Positions the page immediately before the row whose ordering key is `key`.
    ///
    /// The key is stored as an `After` offset and the paginator is flagged as reversed.
    /// `exec` flips the ordering when it runs the query, so "after `key` in the reversed
    /// order" selects the rows that come before `key` in the original order.
    ///
    /// # Panics
    ///
    /// Panics if the query has no limit clause.
    pub fn before(self, key: impl Into<stmt::Expr>) -> Self {
        self.with_cursor(key.into(), true)
    }

    /// Installs `key` as the query's `After` offset and records the paging direction.
    fn with_cursor(mut self, key: stmt::Expr, reverse: bool) -> Self {
        let Some(limit) = self.query.untyped.limit.as_mut() else {
            panic!("pagination requires a limit clause");
        };
        limit.offset = Some(stmt::Offset::After(key));
        self.reverse = reverse;
        self
    }
}
impl<M: Load> Paginate<M> {
pub async fn exec(mut self, executor: &mut dyn Executor) -> Result<super::Page<M::Output>> {
let page_size = match &self.query.untyped.limit {
Some(stmt::Limit {
limit: stmt::Expr::Value(stmt::Value::I64(n)),
..
}) => *n as usize,
_ => {
let res = executor
.exec_untyped(self.query.clone().into_statement().untyped)
.await?;
let stmt::Value::List(values) = res else {
todo!()
};
let items: Vec<M::Output> =
values.into_iter().map(M::load).collect::<Result<_>>()?;
return Ok(super::Page::new(
items,
Query::from_untyped(self.query.untyped),
None,
None,
));
}
};
let mut query_with_extra = self.query.clone();
if let Some(stmt::Limit { limit, .. }) = &mut query_with_extra.untyped.limit {
*limit = stmt::Value::from((page_size + 1) as i64).into();
}
let Some(order_by) = query_with_extra.untyped.order_by.as_mut() else {
panic!("pagination requires order by clause");
};
if self.reverse {
order_by.reverse();
}
let res = executor
.exec_untyped(query_with_extra.into_statement().untyped)
.await?;
let stmt::Value::List(mut items) = res else {
todo!()
};
let has_next = (items.len() > page_size) || self.reverse;
let has_prev = (items.len() > page_size) || !self.reverse;
items.truncate(page_size);
if self.reverse {
items.reverse();
}
let Some(order_by) = self.query.untyped.order_by.as_mut() else {
panic!("pagination requires order by clause");
};
let prev_cursor = match items.first() {
Some(first_item) if has_prev => {
extract_cursor(order_by, first_item).map(|cursor| cursor.into())
}
_ => None,
};
let next_cursor = match items.last() {
Some(last_item) if has_next => {
extract_cursor(order_by, last_item).map(|cursor| cursor.into())
}
_ => None,
};
let loaded_items: Vec<M::Output> = items.into_iter().map(M::load).collect::<Result<_>>()?;
Ok(super::Page::new(
loaded_items,
Query::from_untyped(self.query.untyped),
next_cursor,
prev_cursor,
))
}
}
impl<M> From<Query<List<M>>> for Paginate<M> {
fn from(value: Query<List<M>>) -> Self {
assert!(
value.untyped.limit.is_some(),
"pagination requires a limit clause"
);
assert!(
value.untyped.order_by.is_some(),
"pagination requires an order_by clause"
);
Paginate {
query: value,
reverse: false,
}
}
}
/// Computes the pagination cursor for `item`: a record holding the value of each
/// `order_by` expression evaluated against that row.
///
/// Field references at nesting level 0 are rewritten into projections of the evaluated
/// function's first argument, which is `item` itself. Returns `None` if evaluation fails.
fn extract_cursor(order_by: &OrderBy, item: &Value) -> Option<Value> {
    /// Rewrites top-level field references into projections of argument 0.
    struct FieldRefsToArg;

    impl VisitMut for FieldRefsToArg {
        fn visit_expr_mut(&mut self, expr: &mut stmt::Expr) {
            if let stmt::Expr::Reference(stmt::ExprReference::Field { nesting: 0, index }) = expr
            {
                let projection = Projection::from_index(*index);
                *expr = Expr::arg_project(0, projection);
            } else {
                visit_mut::visit_expr_mut(self, expr);
            }
        }
    }

    let key_exprs = order_by.exprs.iter().map(|entry| {
        let mut rewritten = entry.expr.clone();
        FieldRefsToArg.visit_mut(&mut rewritten);
        rewritten
    });
    let cursor_record = Expr::Record(ExprRecord::from_iter(key_exprs));

    Func::from_stmt(cursor_record, vec![item.infer_ty()])
        .eval(&[item])
        .ok()
}