use std::ops::{Bound, RangeBounds};
use obj_core::codec::Dynamic;
use obj_core::{Error, Result};
use crate::Db;
use crate::Document;
/// Boxed predicate over a borrowed document; a document is selected only when
/// every registered predicate returns `true`.
type FilterFn<T> = Box<dyn Fn(&T) -> bool + 'static>;
/// Boxed extractor producing the byte string a document is ordered by during a
/// sorted fetch; it is fallible so that key-encoding errors can be propagated.
type SortKeyFn<T> = Box<dyn Fn(&T) -> Result<Vec<u8>> + 'static>;
/// Default maximum number of matching documents a sorted fetch will buffer
/// before failing with `Error::SortBufferExceeded`; used when no per-query
/// sort buffer limit has been set.
pub const MAX_SORT_BUFFER: usize = 100_000;
/// Where a query draws its candidate documents from.
#[derive(Debug, Clone)]
enum Source {
/// Every document in the collection.
Full,
/// Only documents reached through the named index within an encoded key range.
IndexRange {
/// Name of the index to read.
name: String,
/// Lower bound, expressed over already-encoded key bytes.
start: Bound<Vec<u8>>,
/// Upper bound, expressed over already-encoded key bytes.
end: Bound<Vec<u8>>,
},
}
/// Builder for a read over the `T` documents of a [`Db`]: a candidate source
/// plus optional filters, a result limit and a sort key.
pub struct Query<'db, T: Document> {
/// Database whose read transaction executes the query.
db: &'db Db,
/// Where candidates come from (full scan or an index range).
source: Source,
/// Predicates a document must all satisfy to be selected.
filters: Vec<FilterFn<T>>,
/// Maximum number of documents to return or count; `None` means no cap.
limit: Option<usize>,
/// Sort-key extractor; when set, fetching takes the buffered, sorted path.
sort_key: Option<SortKeyFn<T>>,
/// Per-query override of `MAX_SORT_BUFFER`; `None` uses the default.
sort_buffer_limit: Option<usize>,
}
impl<'db, T: Document> Query<'db, T>
where
T: Send + 'static,
{
pub(crate) fn new(db: &'db Db) -> Self {
Self {
db,
source: Source::Full,
filters: Vec::new(),
limit: None,
sort_key: None,
sort_buffer_limit: None,
}
}
#[must_use]
pub fn filter<F>(mut self, predicate: F) -> Self
where
F: Fn(&T) -> bool + 'static,
{
self.filters.push(Box::new(predicate));
self
}
#[must_use]
pub fn limit(mut self, n: usize) -> Self {
self.limit = Some(n);
self
}
pub fn index_range<R>(mut self, name: &str, range: R) -> Result<Self>
where
R: RangeBounds<Dynamic>,
{
let start = encode_bound(range.start_bound())?;
let end = encode_bound(range.end_bound())?;
self.source = Source::IndexRange {
name: name.to_owned(),
start,
end,
};
Ok(self)
}
#[must_use]
pub fn sort_by<F>(mut self, key: F) -> Self
where
F: Fn(&T) -> Dynamic + 'static,
{
let encoded: SortKeyFn<T> = Box::new(move |doc: &T| {
let dynamic = key(doc);
obj_core::index::encode_field(&dynamic)
.map(obj_core::index::EncodedIndexKey::into_bytes)
.map_err(|e| Error::SortKeyEncode {
source: Box::new(e),
})
});
self.sort_key = Some(encoded);
self
}
#[must_use]
pub fn sort_by_bytes<F>(mut self, key: F) -> Self
where
F: Fn(&T) -> Vec<u8> + 'static,
{
let encoded: SortKeyFn<T> = Box::new(move |doc: &T| Ok(key(doc)));
self.sort_key = Some(encoded);
self
}
#[must_use]
pub fn sort_buffer_limit(mut self, n: usize) -> Self {
self.sort_buffer_limit = Some(n);
self
}
pub fn fetch(self) -> Result<Vec<T>> {
#[cfg(feature = "tracing")]
let span = tracing::debug_span!("query.execute", kind = tracing::field::Empty);
#[cfg(feature = "tracing")]
let _guard = span.enter();
#[cfg(feature = "tracing")]
span.record("kind", query_kind(&self.source));
self.db.read_transaction(|tx| {
let coll = tx.collection::<T>()?;
if self.sort_key.is_some() {
fetch_sorted(&coll, &self)
} else {
fetch_unsorted(&coll, &self)
}
})
}
pub fn count(&self) -> Result<u64> {
#[cfg(feature = "tracing")]
let span = tracing::debug_span!("query.execute", kind = tracing::field::Empty);
#[cfg(feature = "tracing")]
let _guard = span.enter();
#[cfg(feature = "tracing")]
span.record("kind", query_kind(&self.source));
self.db.read_transaction(|tx| {
let coll = tx.collection::<T>()?;
let total = if self.filters.is_empty() {
count_fast(&coll, &self.source)?
} else {
count_slow(&coll, self)?
};
Ok(apply_count_limit(total, self.limit))
})
}
}
/// Label recorded in the `kind` field of the query tracing span:
/// `"filter"` for a full scan, `"index"` for an index-range read.
#[cfg(feature = "tracing")]
fn query_kind(source: &Source) -> &'static str {
    match *source {
        Source::IndexRange { .. } => "index",
        Source::Full => "filter",
    }
}
/// Clamps `total` to `limit` when a limit is present; returns `total`
/// unchanged otherwise.
///
/// A limit that does not fit in `u64` is treated as `u64::MAX`, i.e. it can
/// never reduce `total`.
fn apply_count_limit(total: u64, limit: Option<usize>) -> u64 {
    limit.map_or(total, |n| {
        let ceiling = u64::try_from(n).unwrap_or(u64::MAX);
        total.min(ceiling)
    })
}
/// Collects, in candidate order, the documents that pass every filter of `q`,
/// stopping the scan as soon as `q.limit` documents have been gathered.
///
/// A limit of zero returns an empty vector without visiting any candidate.
///
/// # Errors
///
/// Propagates any error raised while iterating the candidates.
fn fetch_unsorted<T>(coll: &crate::Collection<'_, T>, q: &Query<'_, T>) -> Result<Vec<T>>
where
    T: Document + Send + 'static,
{
    let limit = q.limit;
    if let Some(0) = limit {
        return Ok(Vec::new());
    }
    let mut matched: Vec<T> = Vec::new();
    for_each_candidate(coll, q, |doc| {
        // Rejected documents are skipped; the scan itself continues.
        if q.filters.iter().any(|pred| !pred(&doc)) {
            return Ok(true);
        }
        matched.push(doc);
        // Keep scanning only while the limit (if any) is not yet reached.
        let keep_going = match limit {
            Some(n) => matched.len() < n,
            None => true,
        };
        Ok(keep_going)
    })?;
    Ok(matched)
}
fn fetch_sorted<T>(coll: &crate::Collection<'_, T>, q: &Query<'_, T>) -> Result<Vec<T>>
where
T: Document + Send + 'static,
{
let cap = q.sort_buffer_limit.unwrap_or(MAX_SORT_BUFFER);
let sort_key = q
.sort_key
.as_ref()
.ok_or(Error::InvalidArgument("fetch_sorted without sort_key"))?;
let mut buf: Vec<(Vec<u8>, T)> = Vec::new();
for_each_candidate(coll, q, |doc| {
if !q.filters.iter().all(|f| f(&doc)) {
return Ok(true);
}
if buf.len() >= cap {
return Err(Error::SortBufferExceeded { limit: cap });
}
let key_bytes = sort_key(&doc)?;
buf.push((key_bytes, doc));
Ok(true)
})?;
buf.sort_by(|a, b| a.0.cmp(&b.0));
let truncated_len = match q.limit {
Some(n) => buf.len().min(n),
None => buf.len(),
};
let mut out: Vec<T> = Vec::with_capacity(truncated_len);
for (_k, d) in buf.into_iter().take(truncated_len) {
out.push(d);
}
Ok(out)
}
/// Feeds each candidate document of `q`'s source to `f`, in the order the
/// collection yields them, until `f` returns `Ok(false)` or the candidates
/// run out.
///
/// Filters are not applied here; that is left to `f`.
///
/// # Errors
///
/// Propagates the first error from the collection, from an index iteration
/// step, or from `f`.
fn for_each_candidate<T, F>(
    coll: &crate::Collection<'_, T>,
    q: &Query<'_, T>,
    mut f: F,
) -> Result<()>
where
    T: Document + Send + 'static,
    F: FnMut(T) -> Result<bool>,
{
    match &q.source {
        Source::Full => {
            for (_id, doc) in coll.all()? {
                let keep_going = f(doc)?;
                if !keep_going {
                    break;
                }
            }
        }
        Source::IndexRange { name, start, end } => {
            let entries =
                coll.index_range_encoded(name, clone_bound(start), clone_bound(end))?;
            for entry in entries {
                let (_key, doc) = entry?;
                let keep_going = f(doc)?;
                if !keep_going {
                    break;
                }
            }
        }
    }
    Ok(())
}
/// Counts the candidates of `source` through the collection's counting calls,
/// without applying any filter.
///
/// For an index range, an index of kind `Each` is counted by distinct
/// document ids and any other kind by index entries.
/// NOTE(review): presumably an `Each` index can list one document under
/// several keys — confirm against the index implementation.
///
/// # Errors
///
/// Propagates any error from the index-kind lookup or the counting call.
fn count_fast<T>(coll: &crate::Collection<'_, T>, source: &Source) -> Result<u64>
where
    T: Document + Send + 'static,
{
    let (name, start, end) = match source {
        Source::Full => return coll.count_all(),
        Source::IndexRange { name, start, end } => (name, start, end),
    };
    let is_each = coll.index_kind(name)? == obj_core::IndexKind::Each;
    let (lower, upper) = (clone_bound(start), clone_bound(end));
    if is_each {
        coll.count_distinct_ids_in_range_encoded(name, lower, upper)
    } else {
        coll.count_index_range_encoded(name, lower, upper)
    }
}
fn count_slow<T>(coll: &crate::Collection<'_, T>, q: &Query<'_, T>) -> Result<u64>
where
T: Document + Send + 'static,
{
let mut n: u64 = 0;
for_each_candidate(coll, q, |doc| {
if q.filters.iter().all(|f| f(&doc)) {
n = n.checked_add(1).ok_or(Error::BTreeInvariantViolated {
reason: "slow-path count exceeds u64",
})?;
}
Ok(true)
})?;
Ok(n)
}
/// Converts a range bound over `Dynamic` values into the same kind of bound
/// over index-key bytes, encoding the endpoint with
/// `obj_core::index::encode_field`. `Unbounded` passes through untouched.
///
/// # Errors
///
/// Propagates the encoder's error when the endpoint cannot be encoded.
fn encode_bound(b: Bound<&Dynamic>) -> Result<Bound<Vec<u8>>> {
    let encode = |value: &Dynamic| -> Result<Vec<u8>> {
        Ok(obj_core::index::encode_field(value)?.into_bytes())
    };
    Ok(match b {
        Bound::Unbounded => Bound::Unbounded,
        Bound::Included(value) => Bound::Included(encode(value)?),
        Bound::Excluded(value) => Bound::Excluded(encode(value)?),
    })
}
/// Returns an owned copy of a borrowed byte-key bound, preserving its kind
/// (included, excluded or unbounded).
fn clone_bound(b: &Bound<Vec<u8>>) -> Bound<Vec<u8>> {
    // `Bound<T>` is `Clone` whenever `T` is, so no per-variant match is needed.
    b.clone()
}