use std::marker::PhantomData;
use std::ops::{Bound, RangeBounds};
use std::sync::Arc;
use obj_core::codec::Dynamic;
use obj_core::{Document, Result};
use crate::asynchronous::db::unblock;
use crate::Db;
/// Boxed `Fn(&T) -> bool` predicate, as stored by `AsyncQuery::filter`.
///
/// The `Send + 'static` bounds let the box be moved into the closure that
/// `fetch`/`count` pass to `unblock` (presumably executed on another thread —
/// confirm against `unblock`'s signature).
type FilterFn<T> = Box<dyn Fn(&T) -> bool + Send + 'static>;
/// Boxed sort-key extractor that derives a [`Dynamic`] value from a document,
/// as stored by `AsyncQuery::sort_by`.
type SortDynamicFn<T> = Box<dyn Fn(&T) -> Dynamic + Send + 'static>;
/// Boxed sort-key extractor that derives a byte vector from a document,
/// as stored by `AsyncQuery::sort_by_bytes`.
type SortBytesFn<T> = Box<dyn Fn(&T) -> Vec<u8> + Send + 'static>;
/// The sort-key extractor configured on an `AsyncQuery`, tagged by the kind
/// of key it produces so it can be forwarded to the matching blocking method
/// (`sort_by` or `sort_by_bytes`).
enum SortKey<T> {
// Key extractor returning a `Dynamic`; forwarded to the blocking `sort_by`.
Dynamic(SortDynamicFn<T>),
// Key extractor returning raw bytes; forwarded to the blocking `sort_by_bytes`.
Bytes(SortBytesFn<T>),
}
/// Where the query reads from, recorded as owned data so it can be moved into
/// the closure that builds the blocking query.
#[derive(Debug, Clone)]
enum AsyncSource {
// No index restriction: the blocking query is used as returned by `Db::query`.
Full,
// Restrict the query to a range of the named index; forwarded to the
// blocking query's `index_range` as the `(start, end)` bound pair.
IndexRange {
// Index name, copied from the caller's `&str`.
name: String,
// Owned lower bound of the range.
start: Bound<Dynamic>,
// Owned upper bound of the range.
end: Bound<Dynamic>,
},
}
/// Builder for a query over documents of type `T` whose terminal operations
/// (`fetch`, `count`) are `async`.
///
/// The builder only records settings; the blocking `crate::Query` is
/// constructed from them inside the closure handed to `unblock` when a
/// terminal operation is awaited.
pub struct AsyncQuery<T> {
// Shared handle to the database the blocking query is created from.
db: Arc<Db>,
// Full scan or a named index range.
source: AsyncSource,
// Predicates in the order they were added via `filter`.
filters: Vec<FilterFn<T>>,
// Value forwarded to the blocking query's `limit`, if set.
limit: Option<usize>,
// Sort-key extractor, if one was configured.
sort_key: Option<SortKey<T>>,
// Value forwarded to the blocking query's `sort_buffer_limit`, if set.
sort_buffer_limit: Option<usize>,
// Ties the builder to the document type `T` without storing a `T`.
_phantom: PhantomData<fn() -> T>,
}
/// Manual `Debug` so the boxed closures (which are not `Debug`) never need to
/// be printed: filters are summarised by their count and the sort key by
/// whether one is set. `db` and `_phantom` are omitted, which is why the
/// output ends with `..`.
impl<T> std::fmt::Debug for AsyncQuery<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Closure-bearing fields are reduced to plain values before printing.
        let filter_count = self.filters.len();
        let has_sort_key = self.sort_key.is_some();

        let mut out = f.debug_struct("AsyncQuery");
        out.field("source", &self.source);
        out.field("filters", &filter_count);
        out.field("limit", &self.limit);
        out.field("sort_key", &has_sort_key);
        out.field("sort_buffer_limit", &self.sort_buffer_limit);
        out.finish_non_exhaustive()
    }
}
impl<T> AsyncQuery<T>
where
T: Document + Send + 'static,
{
pub(crate) fn new(db: Arc<Db>) -> Self {
Self {
db,
source: AsyncSource::Full,
filters: Vec::new(),
limit: None,
sort_key: None,
sort_buffer_limit: None,
_phantom: PhantomData,
}
}
#[must_use]
pub fn filter<F>(mut self, predicate: F) -> Self
where
F: Fn(&T) -> bool + Send + 'static,
{
self.filters.push(Box::new(predicate));
self
}
#[must_use]
pub fn limit(mut self, n: usize) -> Self {
self.limit = Some(n);
self
}
#[must_use]
pub fn index_range<R>(mut self, name: &str, range: R) -> Self
where
R: RangeBounds<Dynamic>,
{
self.source = AsyncSource::IndexRange {
name: name.to_owned(),
start: clone_dynamic_bound(range.start_bound()),
end: clone_dynamic_bound(range.end_bound()),
};
self
}
#[must_use]
pub fn sort_by<F>(mut self, key: F) -> Self
where
F: Fn(&T) -> Dynamic + Send + 'static,
{
self.sort_key = Some(SortKey::Dynamic(Box::new(key)));
self
}
#[must_use]
pub fn sort_by_bytes<F>(mut self, key: F) -> Self
where
F: Fn(&T) -> Vec<u8> + Send + 'static,
{
self.sort_key = Some(SortKey::Bytes(Box::new(key)));
self
}
#[must_use]
pub fn sort_buffer_limit(mut self, n: usize) -> Self {
self.sort_buffer_limit = Some(n);
self
}
pub async fn fetch(self) -> Result<Vec<T>> {
let AsyncQuery {
db,
source,
filters,
limit,
sort_key,
sort_buffer_limit,
_phantom,
} = self;
unblock(move || {
let q = build_blocking_query::<T>(
&db,
source,
filters,
limit,
sort_key,
sort_buffer_limit,
)?;
q.fetch()
})
.await
}
pub async fn count(self) -> Result<u64> {
let AsyncQuery {
db,
source,
filters,
limit,
sort_key,
sort_buffer_limit,
_phantom,
} = self;
unblock(move || {
let q = build_blocking_query::<T>(
&db,
source,
filters,
limit,
sort_key,
sort_buffer_limit,
)?;
q.count()
})
.await
}
}
/// Translates the settings recorded by an `AsyncQuery` into a blocking
/// `crate::Query` borrowed from `db`.
///
/// Settings are applied in a fixed order: index range, filters (in insertion
/// order), sort key, limit, sort-buffer limit. Unset options are skipped.
///
/// # Errors
///
/// Propagates the error returned by the blocking query's `index_range` when
/// `source` is an index range; no other step here is fallible.
fn build_blocking_query<T>(
    db: &Db,
    source: AsyncSource,
    filters: Vec<FilterFn<T>>,
    limit: Option<usize>,
    sort_key: Option<SortKey<T>>,
    sort_buffer_limit: Option<usize>,
) -> Result<crate::Query<'_, T>>
where
    T: Document + Send + 'static,
{
    let unrestricted = db.query::<T>();

    let ranged = match source {
        AsyncSource::Full => unrestricted,
        AsyncSource::IndexRange { name, start, end } => {
            unrestricted.index_range(&name, (start, end))?
        }
    };

    // Each boxed predicate is wrapped in a fresh closure and chained onto the query.
    let filtered = filters
        .into_iter()
        .fold(ranged, |acc, predicate| acc.filter(move |doc| predicate(doc)));

    let sorted = match sort_key {
        Some(SortKey::Dynamic(key_fn)) => filtered.sort_by(move |doc| key_fn(doc)),
        Some(SortKey::Bytes(key_fn)) => filtered.sort_by_bytes(move |doc| key_fn(doc)),
        None => filtered,
    };

    let limited = match limit {
        Some(n) => sorted.limit(n),
        None => sorted,
    };

    let query = match sort_buffer_limit {
        Some(n) => limited.sort_buffer_limit(n),
        None => limited,
    };

    Ok(query)
}
/// Turns a borrowed range bound into an owned one by cloning the `Dynamic`
/// it refers to.
///
/// `Included` and `Excluded` keep their kind and carry a clone of the value;
/// `Unbounded` stays `Unbounded`. This is exactly what the standard library's
/// `Bound::<&T>::cloned` does, so it is used instead of a hand-written match.
fn clone_dynamic_bound(b: Bound<&Dynamic>) -> Bound<Dynamic> {
    b.cloned()
}