use crate::core::DocId;
use std::collections::HashMap;
use super::*;
use crate::segment::reader::SegmentReader;
pub(super) use crate::columnar::owned::OwnedColumn;
/// Factory for a terms aggregation: one bucket per distinct keyword value of a
/// field, optionally carrying sub-aggregations inside every bucket.
pub struct TermsAggFactory {
    /// Field name, matched against `field_name` entries of the segment header.
    pub field_name: String,
    /// Maximum number of buckets kept when per-segment results are merged.
    pub size: usize,
    /// Named sub-aggregation factories instantiated once per bucket.
    pub sub_agg_factories: Vec<(String, Box<dyn AggregatorFactory>)>,
}
impl AggregatorFactory for TermsAggFactory {
    /// Creates the per-segment terms collector.
    ///
    /// The field name is resolved to a field id through the segment header and
    /// the column is opened with `OwnedColumn::new`. The bucket vector is
    /// pre-sized from the column's `dict_size()` (0 when no column exists).
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let field_id = reader
            .header()
            .fields
            .iter()
            .find(|f| f.field_name == self.field_name)
            .map(|f| f.field_id);
        let col = OwnedColumn::new(field_id, reader);
        let dict_size = col.as_ref().map_or(0, |c| c.dict_size());
        Box::new(TermsCollector {
            col,
            // The collector stores raw pointers to the reader and to this
            // factory's sub-aggregation list; it dereferences them while
            // collecting. NOTE(review): this presumably relies on both
            // outliving the collector — confirm against callers.
            segment_data: reader as *const SegmentReader,
            ordinal_buckets: Vec::with_capacity(dict_size),
            sub_agg_factories: self.sub_agg_factories.as_slice()
                as *const [(String, Box<dyn AggregatorFactory>)],
        })
    }

    /// Merges per-segment bucket results into the final top-`size` buckets.
    ///
    /// Buckets with the same string key are summed and their sub-aggregation
    /// partials are merged by the sub-factory registered under the same name.
    /// Buckets whose key is not a `BucketKey::String` are ignored.
    ///
    /// Buckets are ordered by descending `doc_count`; ties are broken by
    /// ascending key so that the `size` cut-off does not depend on `HashMap`
    /// iteration order.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let mut merged: HashMap<String, (u64, HashMap<String, Vec<AggregationResult>>)> =
            HashMap::new();
        for result in results {
            let AggregationResult::Bucket(br) = result else {
                continue;
            };
            for bucket in br.buckets {
                let BucketKey::String(key) = bucket.key else {
                    continue;
                };
                let entry = merged.entry(key).or_insert_with(|| (0, HashMap::new()));
                entry.0 += bucket.doc_count;
                for (name, sub_result) in bucket.sub_aggs {
                    entry.1.entry(name).or_default().push(sub_result);
                }
            }
        }

        let mut bucket_entries: Vec<_> = merged.into_iter().collect();
        // Keys are unique, so this comparator is a total order and the unstable
        // sort is fully deterministic.
        bucket_entries.sort_unstable_by(|(key_a, (count_a, _)), (key_b, (count_b, _))| {
            count_b.cmp(count_a).then_with(|| key_a.cmp(key_b))
        });
        bucket_entries.truncate(self.size);

        let buckets = bucket_entries
            .into_iter()
            .map(|(key, (doc_count, mut sub_partials))| {
                // Take each partial list out of the map instead of cloning it.
                let sub_aggs = self
                    .sub_agg_factories
                    .iter()
                    .filter_map(|(name, factory)| {
                        let partials = sub_partials.remove(name)?;
                        Some((name.clone(), factory.merge_results(partials)))
                    })
                    .collect();
                Bucket {
                    key: BucketKey::String(key),
                    doc_count,
                    sub_aggs,
                }
            })
            .collect();
        AggregationResult::Bucket(BucketResult { buckets })
    }
}
/// State accumulated for a single term bucket within one segment.
struct BucketState {
    /// Number of documents collected into the bucket.
    count: u64,
    /// Sub-aggregation collectors of the bucket, paired with their names.
    sub_collectors: Vec<(String, Box<dyn Aggregator>)>,
}
/// Per-segment collector of a terms aggregation; buckets are indexed by the
/// keyword ordinal of each collected document.
struct TermsCollector {
    /// Column of the aggregated field; when `None`, `collect` does nothing.
    col: Option<OwnedColumn>,
    /// Raw pointer to the segment reader, dereferenced while collecting in
    /// order to create sub-collectors.
    segment_data: *const SegmentReader,
    /// Bucket state indexed by keyword ordinal; slots are created lazily.
    ordinal_buckets: Vec<Option<BucketState>>,
    /// Raw pointer to the sub-aggregation factories of the owning factory.
    sub_agg_factories: *const [(String, Box<dyn AggregatorFactory>)],
}

// SAFETY: NOTE(review) — the raw pointers are only ever read through shared
// references. Soundness presumably depends on the reader and the factory
// outliving this collector and being safe to access from the thread the
// collector is moved to; confirm against the call sites.
unsafe impl Send for TermsCollector {}
impl Aggregator for TermsCollector {
fn collect(&mut self, doc_id: DocId) {
let Some(ord) = self
.col
.as_ref()
.and_then(|c| c.keyword_ordinal(doc_id.as_u32()))
else {
return;
};
let ordinal = ord as usize;
if ordinal >= self.ordinal_buckets.len() {
self.ordinal_buckets.resize_with(ordinal + 1, || None);
}
let reader = unsafe { &*self.segment_data };
let sub_factories = unsafe { &*self.sub_agg_factories };
let state = self.ordinal_buckets[ordinal].get_or_insert_with(|| {
let subs = sub_factories
.iter()
.map(|(name, factory)| (name.clone(), factory.create_collector(reader)))
.collect();
BucketState {
count: 0,
sub_collectors: subs,
}
});
state.count += 1;
for (_, collector) in &mut state.sub_collectors {
collector.collect(doc_id);
}
}
fn finish(self: Box<Self>) -> AggregationResult {
if let Some(col) = self.col.as_ref() {
col.ensure_dict();
}
let mut buckets: Vec<Bucket> = self
.ordinal_buckets
.into_iter()
.enumerate()
.filter_map(|(ordinal, state)| {
let state = state?;
let key = self
.col
.as_ref()
.and_then(|c| c.ordinal_to_string(ordinal as u32))
.unwrap_or("?")
.to_string();
let sub_aggs: HashMap<String, AggregationResult> = state
.sub_collectors
.into_iter()
.map(|(name, collector)| (name, collector.finish()))
.collect();
Some(Bucket {
key: BucketKey::String(key),
doc_count: state.count,
sub_aggs,
})
})
.collect();
buckets.sort_by(|a, b| b.doc_count.cmp(&a.doc_count));
AggregationResult::Bucket(BucketResult { buckets })
}
}
/// Factory for a range aggregation: one bucket per configured range.
pub struct RangeAggFactory {
    /// Field name, matched against `field_name` entries of the segment header.
    pub field_name: String,
    /// Range definitions; every result contains one bucket per entry, in order.
    pub ranges: Vec<RangeDef>,
}
impl AggregatorFactory for RangeAggFactory {
    /// Creates a per-segment collector with one zeroed counter per range.
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let field_id = reader
            .header()
            .fields
            .iter()
            .find_map(|f| (f.field_name == self.field_name).then(|| f.field_id));
        Box::new(RangeCollector {
            col: OwnedColumn::new(field_id, reader),
            ranges: self.ranges.clone(),
            counts: vec![0u64; self.ranges.len()],
        })
    }

    /// Sums per-segment counts position by position.
    ///
    /// Partial buckets are matched to ranges by index; buckets beyond the
    /// number of configured ranges are ignored, as are non-bucket results.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let mut totals = vec![0u64; self.ranges.len()];
        for partial in results {
            if let AggregationResult::Bucket(br) = partial {
                // `zip` stops at the shorter side, i.e. at `totals.len()`.
                for (total, bucket) in totals.iter_mut().zip(br.buckets.iter()) {
                    *total += bucket.doc_count;
                }
            }
        }
        let buckets = self
            .ranges
            .iter()
            .zip(totals)
            .map(|(range, doc_count)| Bucket {
                key: BucketKey::Range {
                    from: range.from,
                    to: range.to,
                },
                doc_count,
                sub_aggs: HashMap::new(),
            })
            .collect();
        AggregationResult::Bucket(BucketResult { buckets })
    }
}
/// Per-segment collector of a range aggregation.
struct RangeCollector {
    /// Column of the aggregated field; when `None`, `collect` does nothing.
    col: Option<OwnedColumn>,
    /// Range definitions, copied from the factory.
    ranges: Vec<RangeDef>,
    /// Document count per range, parallel to `ranges`.
    counts: Vec<u64>,
}

// SAFETY: NOTE(review) — presumably required because `OwnedColumn` is not
// automatically `Send`; confirm that it may be moved across threads.
unsafe impl Send for RangeCollector {}
impl Aggregator for RangeCollector {
    /// Increments the counter of every range containing the document's value.
    ///
    /// A range matches when `from <= value < to`; a missing bound is open.
    /// Ranges may overlap, so one document can be counted several times.
    fn collect(&mut self, doc_id: DocId) {
        let value = match self
            .col
            .as_ref()
            .and_then(|c| c.numeric_value(doc_id.as_u32()))
        {
            Some(v) => v,
            None => return,
        };
        for (range, count) in self.ranges.iter().zip(self.counts.iter_mut()) {
            let above_from = match range.from {
                Some(lower) => value >= lower,
                None => true,
            };
            let below_to = match range.to {
                Some(upper) => value < upper,
                None => true,
            };
            if above_from && below_to {
                *count += 1;
            }
        }
    }

    /// Emits one bucket per range, in definition order, without sub-aggregations.
    fn finish(self: Box<Self>) -> AggregationResult {
        let mut buckets = Vec::with_capacity(self.ranges.len());
        for (range, &doc_count) in self.ranges.iter().zip(self.counts.iter()) {
            buckets.push(Bucket {
                key: BucketKey::Range {
                    from: range.from,
                    to: range.to,
                },
                doc_count,
                sub_aggs: HashMap::new(),
            });
        }
        AggregationResult::Bucket(BucketResult { buckets })
    }
}
/// Factory for a numeric histogram aggregation with fixed-width buckets.
pub struct HistogramAggFactory {
    /// Field name, matched against `field_name` entries of the segment header.
    pub field_name: String,
    /// Bucket width; each value is divided by it to find its bucket.
    pub interval: f64,
}
impl AggregatorFactory for HistogramAggFactory {
    /// Creates a per-segment histogram collector for the configured field.
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let field_id = reader
            .header()
            .fields
            .iter()
            .find(|f| f.field_name == self.field_name)
            .map(|f| f.field_id);
        let col = OwnedColumn::new(field_id, reader);
        Box::new(HistogramCollector {
            col,
            interval: self.interval,
            buckets: HashMap::new(),
        })
    }

    /// Sums the counts of buckets that share the same numeric key and returns
    /// them in ascending key order.
    ///
    /// Keys are grouped by their exact `f64` value rather than being cast to
    /// `i64`, so fractional keys (e.g. `0.5`, `1.5`) are neither truncated nor
    /// folded into a neighbouring bucket. Buckets whose key is not a
    /// `BucketKey::Number` are ignored.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        // `f64` is not `Hash`, so the bit pattern serves as the map key.
        let mut merged: HashMap<u64, u64> = HashMap::new();
        for result in results {
            let AggregationResult::Bucket(br) = result else {
                continue;
            };
            for bucket in br.buckets {
                if let BucketKey::Number(key) = bucket.key {
                    // Adding 0.0 folds -0.0 into +0.0 so both share one bucket.
                    let bits = (key + 0.0).to_bits();
                    *merged.entry(bits).or_insert(0) += bucket.doc_count;
                }
            }
        }

        let mut entries: Vec<(f64, u64)> = merged
            .into_iter()
            .map(|(bits, count)| (f64::from_bits(bits), count))
            .collect();
        // `total_cmp` is a total order, so the sort is well defined even if a
        // non-finite key slips through.
        entries.sort_by(|a, b| a.0.total_cmp(&b.0));

        let buckets = entries
            .into_iter()
            .map(|(key, doc_count)| Bucket {
                key: BucketKey::Number(key),
                doc_count,
                sub_aggs: HashMap::new(),
            })
            .collect();
        AggregationResult::Bucket(BucketResult { buckets })
    }
}
/// Per-segment collector of a numeric histogram aggregation.
struct HistogramCollector {
    /// Column of the aggregated field; when `None`, `collect` does nothing.
    col: Option<OwnedColumn>,
    /// Bucket width, copied from the factory.
    interval: f64,
    /// Document counts keyed by an integer bucket identifier computed in `collect`.
    buckets: HashMap<i64, u64>,
}

// SAFETY: NOTE(review) — presumably required because `OwnedColumn` is not
// automatically `Send`; confirm that it may be moved across threads.
unsafe impl Send for HistogramCollector {}
impl Aggregator for HistogramCollector {
fn collect(&mut self, doc_id: DocId) {
let Some(v) = self
.col
.as_ref()
.and_then(|c| c.numeric_value(doc_id.as_u32()))
else {
return;
};
let bucket_key = (v / self.interval).floor() as i64 * self.interval as i64;
*self.buckets.entry(bucket_key).or_insert(0) += 1;
}
fn finish(self: Box<Self>) -> AggregationResult {
let mut buckets: Vec<Bucket> = self
.buckets
.into_iter()
.map(|(key, count)| Bucket {
key: BucketKey::Number(key as f64),
doc_count: count,
sub_aggs: HashMap::new(),
})
.collect();
buckets.sort_by(|a, b| {
let ka = if let BucketKey::Number(n) = a.key {
n
} else {
0.0
};
let kb = if let BucketKey::Number(n) = b.key {
n
} else {
0.0
};
ka.partial_cmp(&kb).unwrap_or(std::cmp::Ordering::Equal)
});
AggregationResult::Bucket(BucketResult { buckets })
}
}
/// Factory for a date histogram aggregation over an epoch-millisecond field.
pub struct DateHistogramAggFactory {
    /// Field name, matched against `field_name` entries of the segment header.
    pub field_name: String,
    /// Bucket interval: a fixed width or a calendar unit.
    pub interval: DateInterval,
}
impl AggregatorFactory for DateHistogramAggFactory {
    /// Creates a per-segment date histogram collector for the configured field.
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let field_id = reader
            .header()
            .fields
            .iter()
            .find_map(|f| (f.field_name == self.field_name).then(|| f.field_id));
        Box::new(DateHistogramCollector {
            col: OwnedColumn::new(field_id, reader),
            interval: self.interval.clone(),
            buckets: HashMap::new(),
        })
    }

    /// Sums the counts of buckets sharing the same key (cast to `i64`) and
    /// returns them in ascending key order.
    ///
    /// Buckets whose key is not a `BucketKey::Number` are ignored.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let mut totals: HashMap<i64, u64> = HashMap::new();
        for partial in results {
            let AggregationResult::Bucket(br) = partial else {
                continue;
            };
            for bucket in br.buckets {
                if let BucketKey::Number(key) = bucket.key {
                    *totals.entry(key as i64).or_insert(0) += bucket.doc_count;
                }
            }
        }

        // Ordering the integer keys first gives the same ascending order as
        // comparing the `f64` keys afterwards.
        let mut entries: Vec<(i64, u64)> = totals.into_iter().collect();
        entries.sort_by_key(|&(key, _)| key);

        let buckets = entries
            .into_iter()
            .map(|(key, doc_count)| Bucket {
                key: BucketKey::Number(key as f64),
                doc_count,
                sub_aggs: HashMap::new(),
            })
            .collect();
        AggregationResult::Bucket(BucketResult { buckets })
    }
}
/// Per-segment collector of a date histogram aggregation.
struct DateHistogramCollector {
    /// Column of the aggregated field; when `None`, `collect` does nothing.
    col: Option<OwnedColumn>,
    /// Bucket interval, cloned from the factory.
    interval: DateInterval,
    /// Document counts keyed by the value returned from `bucket_key`.
    buckets: HashMap<i64, u64>,
}

// SAFETY: NOTE(review) — presumably required because `OwnedColumn` is not
// automatically `Send`; confirm that it may be moved across threads.
unsafe impl Send for DateHistogramCollector {}
impl DateHistogramCollector {
    /// Maps a timestamp to the key of the bucket that contains it.
    ///
    /// * `Fixed(ms)` — the number of whole intervals below the timestamp times
    ///   the interval width (the width is cast to `i64`).
    /// * `Calendar(unit)` — delegated to `calendar_floor` after casting the
    ///   timestamp to `i64`.
    fn bucket_key(&self, epoch_millis: f64) -> i64 {
        match &self.interval {
            DateInterval::Fixed(ms) => {
                let whole_intervals = (epoch_millis / ms).floor() as i64;
                let width = *ms as i64;
                whole_intervals * width
            }
            DateInterval::Calendar(cal) => {
                let millis = epoch_millis as i64;
                calendar_floor(millis, cal)
            }
        }
    }
}
/// Rounds `epoch_ms` (milliseconds since 1970-01-01T00:00:00) down to the
/// start of the calendar unit that contains it and returns that instant in
/// milliseconds.
///
/// Euclidean division is used throughout so that timestamps before the epoch
/// are floored as well; plain `/` and `%` truncate toward zero, which would
/// round negative timestamps *up* into the following unit.
///
/// Weeks start on Monday. Months, quarters and years use the calendar
/// produced by `days_to_ymd` / `ymd_to_epoch_ms`.
fn calendar_floor(epoch_ms: i64, interval: &CalendarInterval) -> i64 {
    const MS_PER_SEC: i64 = 1_000;
    const MS_PER_MIN: i64 = 60 * MS_PER_SEC;
    const MS_PER_HOUR: i64 = 60 * MS_PER_MIN;
    const MS_PER_DAY: i64 = 24 * MS_PER_HOUR;

    // Whole days since the epoch, floored (negative before 1970).
    let days = epoch_ms.div_euclid(MS_PER_DAY);

    match interval {
        CalendarInterval::Minute => epoch_ms.div_euclid(MS_PER_MIN) * MS_PER_MIN,
        CalendarInterval::Hour => epoch_ms.div_euclid(MS_PER_HOUR) * MS_PER_HOUR,
        CalendarInterval::Day => days * MS_PER_DAY,
        CalendarInterval::Week => {
            // Day 0 (1970-01-01) was a Thursday, so `(days + 3) mod 7` is the
            // number of days elapsed since the preceding Monday.
            let days_since_monday = (days + 3).rem_euclid(7);
            (days - days_since_monday) * MS_PER_DAY
        }
        CalendarInterval::Month => {
            let (year, month, _) = days_to_ymd(days);
            ymd_to_epoch_ms(year, month, 1)
        }
        CalendarInterval::Quarter => {
            let (year, month, _) = days_to_ymd(days);
            // First month of the quarter: 1, 4, 7 or 10.
            let quarter_start = ((month - 1) / 3) * 3 + 1;
            ymd_to_epoch_ms(year, quarter_start, 1)
        }
        CalendarInterval::Year => {
            let (year, _, _) = days_to_ymd(days);
            ymd_to_epoch_ms(year, 1, 1)
        }
    }
}
/// Converts a day count relative to 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple, with `month` in `1..=12` and `day` in `1..=31`.
///
/// Internally the year is treated as starting on 1 March so that the leap day
/// is the last day of the year; time is split into 400-year eras of 146 097
/// days each.
fn days_to_ymd(days: i64) -> (i32, u32, u32) {
    // Re-base so that day 0 is 0000-03-01.
    let shifted = days + 719_468;
    // Floor division by the era length (also correct for negative values).
    let era = (if shifted >= 0 {
        shifted
    } else {
        shifted - 146_096
    }) / 146_097;
    let day_of_era = (shifted - era * 146_097) as u32;
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    // Day within the March-based year, 0 = 1 March.
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // Month index within the March-based year, 0 = March .. 11 = February.
    let month_index = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    let month = if month_index < 10 {
        month_index + 3
    } else {
        month_index - 9
    };
    // January and February belong to the next civil year.
    let march_based_year = year_of_era as i64 + era * 400;
    let year = if month <= 2 {
        march_based_year + 1
    } else {
        march_based_year
    };
    (year as i32, month, day)
}
/// Converts a proleptic Gregorian date to milliseconds since
/// 1970-01-01T00:00:00 (midnight at the start of the given day).
///
/// `m` is the month (1 = January) and `d` the day of the month (1-based).
/// This is the inverse of `days_to_ymd`, scaled from days to milliseconds.
fn ymd_to_epoch_ms(y: i32, m: u32, d: u32) -> i64 {
    // Use a March-based year: January and February count as months 10 and 11
    // of the previous year.
    let march_based_year = if m <= 2 { y as i64 - 1 } else { y as i64 };
    // Floor division into 400-year eras (also correct for negative years).
    let era = (if march_based_year >= 0 {
        march_based_year
    } else {
        march_based_year - 399
    }) / 400;
    let year_of_era = (march_based_year - era * 400) as u32;
    let month_index = if m > 2 { m - 3 } else { m + 9 };
    let day_of_year = (153 * month_index + 2) / 5 + d - 1;
    let day_of_era = year_of_era * 365 + year_of_era / 4 - year_of_era / 100 + day_of_year;
    let days_since_epoch = era * 146_097 + day_of_era as i64 - 719_468;
    days_since_epoch * 86_400_000
}
impl Aggregator for DateHistogramCollector {
    /// Counts the document in the bucket returned by `bucket_key` for its value.
    ///
    /// Documents without a column or without a numeric value are skipped.
    fn collect(&mut self, doc_id: DocId) {
        let value = match self
            .col
            .as_ref()
            .and_then(|c| c.numeric_value(doc_id.as_u32()))
        {
            Some(v) => v,
            None => return,
        };
        let key = self.bucket_key(value);
        *self.buckets.entry(key).or_insert(0) += 1;
    }

    /// Emits one bucket per populated key, in ascending key order and without
    /// sub-aggregations.
    fn finish(self: Box<Self>) -> AggregationResult {
        // Ordering the integer keys first gives the same ascending order as
        // comparing the `f64` keys afterwards.
        let mut entries: Vec<(i64, u64)> = self.buckets.into_iter().collect();
        entries.sort_by_key(|&(key, _)| key);

        let mut buckets: Vec<Bucket> = Vec::with_capacity(entries.len());
        for (key, doc_count) in entries {
            buckets.push(Bucket {
                key: BucketKey::Number(key as f64),
                doc_count,
                sub_aggs: HashMap::new(),
            });
        }
        AggregationResult::Bucket(BucketResult { buckets })
    }
}
/// Factory for a nested aggregation: sub-aggregations run over the child
/// documents of every collected parent document.
pub struct NestedAggFactory {
    /// Nested object path (not consulted by `create_collector` or `merge_results`).
    pub path: String,
    /// Named sub-aggregation factories that are fed the child documents.
    pub sub_agg_factories: Vec<(String, Box<dyn AggregatorFactory>)>,
}
impl AggregatorFactory for NestedAggFactory {
    /// Creates a per-segment collector holding a copy of the segment's parent
    /// bitset (if any) and one sub-collector per configured sub-aggregation.
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let parent_bitset = reader.parent_bitset().map(|b| b.to_vec());
        let sub_collectors: Vec<(String, Box<dyn Aggregator>)> = self
            .sub_agg_factories
            .iter()
            .map(|(name, factory)| (name.clone(), factory.create_collector(reader)))
            .collect();
        Box::new(NestedAggregator {
            parent_bitset,
            sub_collectors,
        })
    }

    /// Merges per-segment results into a single bucket keyed `"nested"`.
    ///
    /// Document counts are summed. Sub-aggregation partials are grouped by
    /// name and handed to the factory registered under that *same name*.
    /// Pairing factories with the grouped partials positionally would be wrong
    /// because `HashMap` iteration order is arbitrary, so a factory could
    /// receive another aggregation's partial results.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let mut total_count = 0u64;
        let mut sub_results: HashMap<String, Vec<AggregationResult>> = HashMap::new();
        for result in results {
            let AggregationResult::Bucket(br) = result else {
                continue;
            };
            for bucket in br.buckets {
                total_count += bucket.doc_count;
                for (name, sub_result) in bucket.sub_aggs {
                    sub_results.entry(name).or_default().push(sub_result);
                }
            }
        }

        let mut merged_sub_aggs = HashMap::new();
        for (name, factory) in &self.sub_agg_factories {
            // Look the partials up by name; factories without any partial
            // results simply produce no entry.
            if let Some(partials) = sub_results.remove(name) {
                merged_sub_aggs.insert(name.clone(), factory.merge_results(partials));
            }
        }

        AggregationResult::Bucket(BucketResult {
            buckets: vec![Bucket {
                key: BucketKey::String("nested".into()),
                doc_count: total_count,
                sub_aggs: merged_sub_aggs,
            }],
        })
    }
}
/// Per-segment collector of a nested aggregation.
struct NestedAggregator {
    /// One flag per document id; `true` marks a parent document.
    /// `None` when the segment has no parent bitset.
    parent_bitset: Option<Vec<bool>>,
    /// Sub-aggregation collectors that receive the child documents.
    sub_collectors: Vec<(String, Box<dyn Aggregator>)>,
}

// SAFETY: NOTE(review) — presumably required because `Box<dyn Aggregator>` is
// not known to be `Send`; confirm every sub-collector may cross threads.
unsafe impl Send for NestedAggregator {}
impl NestedAggregator {
    /// Returns the ids of the documents that follow `parent_doc` up to (but
    /// excluding) the next document flagged as a parent, or up to the end of
    /// the bitset. Returns an empty list when there is no parent bitset.
    fn children_of(&self, parent_doc: u32) -> Vec<u32> {
        let Some(bitset) = self.parent_bitset.as_ref() else {
            return Vec::new();
        };
        let first_child = parent_doc as usize + 1;
        bitset
            .iter()
            .enumerate()
            .skip(first_child)
            // Stop at the first flagged document: it starts the next block.
            .take_while(|entry| !*entry.1)
            .map(|(doc, _)| doc as u32)
            .collect()
    }
}
impl Aggregator for NestedAggregator {
    /// Forwards every child of the parent document `doc_id` to all
    /// sub-collectors (child by child, each child to every collector in turn).
    fn collect(&mut self, doc_id: DocId) {
        for child_id in self.children_of(doc_id.as_u32()) {
            self.sub_collectors
                .iter_mut()
                .for_each(|(_, collector)| collector.collect(DocId::new(child_id)));
        }
    }

    /// Finishes all sub-collectors and wraps them in one bucket keyed `"nested"`.
    ///
    /// The bucket's `doc_count` is the sum of the `doc_count`s of all buckets
    /// found in bucket-typed sub-results.
    /// NOTE(review): this yields 0 when no sub-aggregation returns buckets and
    /// adds up across several bucket sub-aggregations — confirm this is the
    /// intended definition of the nested document count.
    fn finish(self: Box<Self>) -> AggregationResult {
        let mut total_children = 0u64;
        let mut sub_aggs = HashMap::new();
        for (name, collector) in self.sub_collectors {
            let result = collector.finish();
            if let AggregationResult::Bucket(br) = &result {
                for bucket in br.buckets.iter() {
                    total_children += bucket.doc_count;
                }
            }
            sub_aggs.insert(name, result);
        }
        let nested_bucket = Bucket {
            key: BucketKey::String("nested".into()),
            doc_count: total_children,
            sub_aggs,
        };
        AggregationResult::Bucket(BucketResult {
            buckets: vec![nested_bucket],
        })
    }
}
/// Factory for a reverse-nested aggregation: sub-aggregations run over the
/// parent document of every collected nested document.
pub struct ReverseNestedAggFactory {
    /// Named sub-aggregation factories that are fed the parent documents.
    pub sub_agg_factories: Vec<(String, Box<dyn AggregatorFactory>)>,
}
impl AggregatorFactory for ReverseNestedAggFactory {
    /// Creates a per-segment collector holding a copy of the segment's parent
    /// bitset (if any), one sub-collector per configured sub-aggregation and an
    /// empty set of already-seen parents.
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let parent_bitset = reader.parent_bitset().map(|b| b.to_vec());
        let sub_collectors: Vec<(String, Box<dyn Aggregator>)> = self
            .sub_agg_factories
            .iter()
            .map(|(name, factory)| (name.clone(), factory.create_collector(reader)))
            .collect();
        Box::new(ReverseNestedAggregator {
            parent_bitset,
            sub_collectors,
            seen_parents: std::collections::HashSet::new(),
        })
    }

    /// Merges per-segment results into a single bucket keyed `"reverse_nested"`.
    ///
    /// Document counts are summed. Sub-aggregation partials are grouped by
    /// name and handed to the factory registered under that *same name*.
    /// Pairing factories with the grouped partials positionally would be wrong
    /// because `HashMap` iteration order is arbitrary, so a factory could
    /// receive another aggregation's partial results.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let mut total_count = 0u64;
        let mut sub_results: HashMap<String, Vec<AggregationResult>> = HashMap::new();
        for result in results {
            let AggregationResult::Bucket(br) = result else {
                continue;
            };
            for bucket in br.buckets {
                total_count += bucket.doc_count;
                for (name, sub_result) in bucket.sub_aggs {
                    sub_results.entry(name).or_default().push(sub_result);
                }
            }
        }

        let mut merged_sub_aggs = HashMap::new();
        for (name, factory) in &self.sub_agg_factories {
            // Look the partials up by name; factories without any partial
            // results simply produce no entry.
            if let Some(partials) = sub_results.remove(name) {
                merged_sub_aggs.insert(name.clone(), factory.merge_results(partials));
            }
        }

        AggregationResult::Bucket(BucketResult {
            buckets: vec![Bucket {
                key: BucketKey::String("reverse_nested".into()),
                doc_count: total_count,
                sub_aggs: merged_sub_aggs,
            }],
        })
    }
}
/// Per-segment collector of a reverse-nested aggregation.
struct ReverseNestedAggregator {
    /// One flag per document id; `true` marks a parent document.
    /// `None` when the segment has no parent bitset.
    parent_bitset: Option<Vec<bool>>,
    /// Sub-aggregation collectors that receive the parent documents.
    sub_collectors: Vec<(String, Box<dyn Aggregator>)>,
    /// Parents already forwarded, so each one is collected at most once.
    seen_parents: std::collections::HashSet<u32>,
}

// SAFETY: NOTE(review) — presumably required because `Box<dyn Aggregator>` is
// not known to be `Send`; confirm every sub-collector may cross threads.
unsafe impl Send for ReverseNestedAggregator {}
impl ReverseNestedAggregator {
    /// Finds the closest document at or before `nested_doc` that is flagged as
    /// a parent.
    ///
    /// Positions beyond the end of the bitset are skipped, so the search
    /// effectively starts at the last bitset entry when `nested_doc` is out of
    /// range. Returns `None` when there is no bitset or no flagged document in
    /// the searched prefix.
    fn find_parent(&self, nested_doc: u32) -> Option<u32> {
        let bitset = self.parent_bitset.as_ref()?;
        // Exclusive end of the searched prefix: `nested_doc` itself is included.
        let search_end = bitset.len().min((nested_doc as usize).saturating_add(1));
        bitset[..search_end]
            .iter()
            .rposition(|&is_parent| is_parent)
            .map(|index| index as u32)
    }
}
impl Aggregator for ReverseNestedAggregator {
fn collect(&mut self, doc_id: DocId) {
if let Some(parent_id) = self.find_parent(doc_id.as_u32()) {
if self.seen_parents.insert(parent_id) {
for (_, collector) in &mut self.sub_collectors {
collector.collect(DocId::new(parent_id));
}
}
}
}
fn finish(self: Box<Self>) -> AggregationResult {
let parent_count = self.seen_parents.len() as u64;
let mut sub_aggs = HashMap::new();
for (name, collector) in self.sub_collectors {
sub_aggs.insert(name, collector.finish());
}
AggregationResult::Bucket(BucketResult {
buckets: vec![Bucket {
key: BucketKey::String("reverse_nested".into()),
doc_count: parent_count,
sub_aggs,
}],
})
}
}
/// Factory for a geohash-grid aggregation: one bucket per geohash cell.
pub struct GeohashGridAggFactory {
    /// Field name, matched against `field_name` entries of the segment header.
    pub field_name: String,
    /// Length argument passed to `geohash::encode` for every point.
    pub precision: usize,
    /// Maximum number of buckets kept when per-segment results are merged.
    pub size: usize,
}
impl AggregatorFactory for GeohashGridAggFactory {
    /// Creates a per-segment collector reading from the geo point store of the
    /// configured field (if the field and the store exist).
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let field_id = reader
            .header()
            .fields
            .iter()
            .find(|f| f.field_name == self.field_name)
            .map(|f| f.field_id);
        let store = field_id.and_then(|fid| reader.geo_points(fid));
        Box::new(GeohashGridCollector {
            store,
            precision: self.precision,
            buckets: HashMap::new(),
        })
    }

    /// Sums the counts of equal geohash cells and keeps the `size` most
    /// populated ones.
    ///
    /// Cells are ordered by descending `doc_count`; ties are broken by
    /// ascending geohash so that the `size` cut-off does not depend on
    /// `HashMap` iteration order. Buckets whose key is not a
    /// `BucketKey::String` are ignored.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let mut merged: HashMap<String, u64> = HashMap::new();
        for result in results {
            let AggregationResult::Bucket(br) = result else {
                continue;
            };
            for bucket in br.buckets {
                if let BucketKey::String(key) = bucket.key {
                    *merged.entry(key).or_insert(0) += bucket.doc_count;
                }
            }
        }

        let mut entries: Vec<(String, u64)> = merged.into_iter().collect();
        // Keys are unique, so this comparator is a total order and the unstable
        // sort is fully deterministic.
        entries.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        // Cut before building buckets so discarded cells cost nothing more.
        entries.truncate(self.size);

        let buckets = entries
            .into_iter()
            .map(|(key, doc_count)| Bucket {
                key: BucketKey::String(key),
                doc_count,
                sub_aggs: HashMap::new(),
            })
            .collect();
        AggregationResult::Bucket(BucketResult { buckets })
    }
}
/// Per-segment collector of a geohash-grid aggregation.
struct GeohashGridCollector {
    /// Geo point store of the aggregated field; when `None`, `collect` does nothing.
    store: Option<crate::spatial::geo::GeoPointStore>,
    /// Length argument passed to `geohash::encode`.
    precision: usize,
    /// Document counts keyed by geohash string.
    buckets: HashMap<String, u64>,
}

// SAFETY: NOTE(review) — presumably required because `GeoPointStore` is not
// automatically `Send`; confirm that it may be moved across threads.
unsafe impl Send for GeohashGridCollector {}
impl Aggregator for GeohashGridCollector {
    /// Encodes the document's point as a geohash and counts it in that cell.
    ///
    /// Documents without a store, without a point, or whose point cannot be
    /// encoded are skipped.
    fn collect(&mut self, doc_id: DocId) {
        let Some(store) = &self.store else {
            return;
        };
        let Some(point) = store.get(doc_id.as_u32()) else {
            return;
        };
        // geohash coordinates are (x = longitude, y = latitude).
        let coord = geohash::Coord {
            x: point.lon,
            y: point.lat,
        };
        if let Ok(hash) = geohash::encode(coord, self.precision) {
            *self.buckets.entry(hash).or_insert(0) += 1;
        }
    }

    /// Emits one bucket per populated cell, ordered by descending document
    /// count and without sub-aggregations.
    fn finish(self: Box<Self>) -> AggregationResult {
        let mut buckets: Vec<Bucket> = Vec::with_capacity(self.buckets.len());
        for (key, doc_count) in self.buckets {
            buckets.push(Bucket {
                key: BucketKey::String(key),
                doc_count,
                sub_aggs: HashMap::new(),
            });
        }
        buckets.sort_by_key(|bucket| std::cmp::Reverse(bucket.doc_count));
        AggregationResult::Bucket(BucketResult { buckets })
    }
}
/// Factory for a top-hits aggregation that returns stored documents.
pub struct TopHitsAggFactory {
    /// Maximum number of hits kept per segment and after merging.
    pub size: usize,
}
impl AggregatorFactory for TopHitsAggFactory {
    /// Creates a per-segment collector that remembers up to `size` document ids.
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        // The collector keeps a raw pointer to the reader and dereferences it
        // in `finish`. NOTE(review): this presumably relies on the reader
        // outliving the collector — confirm against callers.
        let segment = reader as *const SegmentReader;
        Box::new(TopHitsCollector {
            segment,
            doc_ids: Vec::new(),
            size: self.size,
        })
    }

    /// Concatenates the hits of all `Hits` results in input order and keeps
    /// the first `size` of them; other result kinds are ignored.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let hits: Vec<serde_json::Value> = results
            .into_iter()
            .filter_map(|result| match result {
                AggregationResult::Hits(h) => Some(h.hits),
                _ => None,
            })
            .flatten()
            .take(self.size)
            .collect();
        AggregationResult::Hits(HitsResult { hits })
    }
}
/// Per-segment collector of a top-hits aggregation.
struct TopHitsCollector {
    /// Raw pointer to the segment reader, dereferenced in `finish` to load
    /// stored documents.
    segment: *const SegmentReader,
    /// Ids of the collected documents, in collection order.
    doc_ids: Vec<u32>,
    /// Maximum number of document ids to remember.
    size: usize,
}

// SAFETY: NOTE(review) — the raw pointer is only read through a shared
// reference. Soundness presumably depends on the reader outliving this
// collector and being safe to access from the thread the collector is moved
// to; confirm against the call sites.
unsafe impl Send for TopHitsCollector {}
impl Aggregator for TopHitsCollector {
    /// Remembers `doc_id` unless `size` ids have already been recorded.
    fn collect(&mut self, doc_id: DocId) {
        if self.doc_ids.len() >= self.size {
            return;
        }
        self.doc_ids.push(doc_id.as_u32());
    }

    /// Loads the stored source of every remembered document and returns the
    /// hits as JSON objects with the fields `_doc_id` and `_source`.
    ///
    /// Documents that are missing from the doc store or whose stored bytes are
    /// not valid JSON are silently left out.
    fn finish(self: Box<Self>) -> AggregationResult {
        // SAFETY: NOTE(review) — the pointer was taken from a live reference
        // when the collector was built; this presumably relies on the reader
        // outliving the collector. Confirm against callers.
        let reader = unsafe { &*self.segment };
        let doc_store = reader.doc_store();
        let hits: Vec<serde_json::Value> = self
            .doc_ids
            .iter()
            .filter_map(|&doc_id| {
                let source_bytes = doc_store.get(doc_id)?;
                let source = serde_json::from_slice::<serde_json::Value>(&source_bytes).ok()?;
                Some(serde_json::json!({
                    "_doc_id": doc_id,
                    "_source": source,
                }))
            })
            .collect();
        AggregationResult::Hits(HitsResult { hits })
    }
}
/// Factory for a single-filter aggregation: counts the collected documents
/// that match a query and runs sub-aggregations over them.
pub struct FilterAggFactory {
    /// Query whose scorer decides which documents of a segment match.
    pub(crate) bound_query: Box<dyn crate::query::BoundQuery>,
    /// Named sub-aggregation factories that are fed the matching documents.
    pub(crate) sub_agg_factories: Vec<(String, Box<dyn AggregatorFactory>)>,
}
impl AggregatorFactory for FilterAggFactory {
    /// Creates a per-segment collector with the filter pre-evaluated.
    ///
    /// The query's scorer is drained once to mark every matching document id
    /// in a bitset sized to the segment's document count. If no scorer can be
    /// obtained (error or `None`), the bitset stays empty of matches.
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let doc_count = reader.doc_count() as usize;
        let mut bitset = vec![false; doc_count];
        if let Ok(Some(supplier)) = self.bound_query.scorer_supplier(reader) {
            if let Ok(mut scorer) = supplier.scorer() {
                while scorer.doc_id() != crate::core::NO_MORE_DOCS {
                    let id = scorer.doc_id().as_u32() as usize;
                    // Ids outside the bitset are ignored.
                    if let Some(flag) = bitset.get_mut(id) {
                        *flag = true;
                    }
                    scorer.next();
                }
            }
        }

        let mut sub_collectors: Vec<(String, Box<dyn Aggregator>)> =
            Vec::with_capacity(self.sub_agg_factories.len());
        for (name, factory) in &self.sub_agg_factories {
            sub_collectors.push((name.clone(), factory.create_collector(reader)));
        }

        Box::new(FilterCollector {
            bitset,
            count: 0,
            sub_collectors,
        })
    }

    /// Merges per-segment results into a single bucket keyed `"filter"`.
    ///
    /// Document counts are summed; sub-aggregation partials are grouped by
    /// name and merged by the factory registered under the same name.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let mut total = 0u64;
        let mut partials_by_name: HashMap<String, Vec<AggregationResult>> = HashMap::new();
        for partial in results {
            let AggregationResult::Bucket(br) = partial else {
                continue;
            };
            for bucket in br.buckets {
                total += bucket.doc_count;
                for (name, sub_result) in bucket.sub_aggs {
                    partials_by_name.entry(name).or_default().push(sub_result);
                }
            }
        }

        let sub_aggs = self
            .sub_agg_factories
            .iter()
            .filter_map(|(name, factory)| {
                partials_by_name
                    .remove(name)
                    .map(|partials| (name.clone(), factory.merge_results(partials)))
            })
            .collect();

        AggregationResult::Bucket(BucketResult {
            buckets: vec![Bucket {
                key: BucketKey::String("filter".into()),
                doc_count: total,
                sub_aggs,
            }],
        })
    }
}
/// Per-segment collector of a single-filter aggregation.
struct FilterCollector {
    /// One flag per document id; `true` when the document matches the filter.
    bitset: Vec<bool>,
    /// Number of collected documents that matched the filter.
    count: u64,
    /// Sub-aggregation collectors that receive the matching documents.
    sub_collectors: Vec<(String, Box<dyn Aggregator>)>,
}
/// Factory for a multi-filter aggregation: one bucket per named query.
pub struct FiltersAggFactory {
    /// Named queries; every result contains one bucket per entry, in order.
    pub(crate) filters: Vec<(String, Box<dyn crate::query::BoundQuery>)>,
}
impl AggregatorFactory for FiltersAggFactory {
    /// Creates a per-segment collector with every filter pre-evaluated.
    ///
    /// For each named query the scorer is drained once to mark matching
    /// document ids in a bitset sized to the segment's document count. A query
    /// for which no scorer can be obtained gets a bitset without matches.
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let doc_count = reader.doc_count() as usize;
        let filter_bitsets: Vec<(String, Vec<bool>)> = self
            .filters
            .iter()
            .map(|(name, bound_query)| {
                let mut bitset = vec![false; doc_count];
                if let Ok(Some(supplier)) = bound_query.scorer_supplier(reader) {
                    if let Ok(mut scorer) = supplier.scorer() {
                        while scorer.doc_id() != crate::core::NO_MORE_DOCS {
                            let id = scorer.doc_id().as_u32() as usize;
                            // Ids outside the bitset are ignored.
                            if let Some(flag) = bitset.get_mut(id) {
                                *flag = true;
                            }
                            scorer.next();
                        }
                    }
                }
                (name.clone(), bitset)
            })
            .collect();
        Box::new(FiltersCollector {
            filter_bitsets,
            counts: vec![0u64; self.filters.len()],
        })
    }

    /// Sums per-segment counts position by position and labels each bucket
    /// with the name of the filter at that position.
    ///
    /// Partial buckets beyond the number of configured filters are ignored, as
    /// are non-bucket results.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let mut totals = vec![0u64; self.filters.len()];
        for partial in results {
            if let AggregationResult::Bucket(br) = partial {
                // `zip` stops at the shorter side, i.e. at `totals.len()`.
                for (total, bucket) in totals.iter_mut().zip(br.buckets.iter()) {
                    *total += bucket.doc_count;
                }
            }
        }
        let buckets = self
            .filters
            .iter()
            .zip(totals)
            .map(|((name, _), doc_count)| Bucket {
                key: BucketKey::String(name.clone()),
                doc_count,
                sub_aggs: HashMap::new(),
            })
            .collect();
        AggregationResult::Bucket(BucketResult { buckets })
    }
}
/// Per-segment collector of a multi-filter aggregation.
///
/// No manual `unsafe impl Send` is needed: every field is an owned standard
/// collection of `Send` types, so the compiler derives `Send` automatically
/// and will flag any future field that is not thread-safe.
struct FiltersCollector {
    /// Filter name paired with one flag per document id; `true` when the
    /// document matches that filter.
    filter_bitsets: Vec<(String, Vec<bool>)>,
    /// Number of collected documents matching each filter, parallel to
    /// `filter_bitsets`.
    counts: Vec<u64>,
}
impl Aggregator for FiltersCollector {
    /// Increments the counter of every filter whose bitset marks `doc_id`.
    ///
    /// Ids outside a bitset count as non-matching.
    fn collect(&mut self, doc_id: DocId) {
        let id = doc_id.as_u32() as usize;
        for ((_, bitset), count) in self.filter_bitsets.iter().zip(self.counts.iter_mut()) {
            let matches = bitset.get(id).copied().unwrap_or(false);
            if matches {
                *count += 1;
            }
        }
    }

    /// Emits one bucket per filter, keyed by the filter name, in filter order
    /// and without sub-aggregations.
    fn finish(self: Box<Self>) -> AggregationResult {
        let mut buckets = Vec::with_capacity(self.filter_bitsets.len());
        for ((name, _), &doc_count) in self.filter_bitsets.iter().zip(self.counts.iter()) {
            buckets.push(Bucket {
                key: BucketKey::String(name.clone()),
                doc_count,
                sub_aggs: HashMap::new(),
            });
        }
        AggregationResult::Bucket(BucketResult { buckets })
    }
}
impl Aggregator for FilterCollector {
fn collect(&mut self, doc_id: DocId) {
let id = doc_id.as_u32() as usize;
if id < self.bitset.len() && self.bitset[id] {
self.count += 1;
for (_, collector) in &mut self.sub_collectors {
collector.collect(doc_id);
}
}
}
fn finish(self: Box<Self>) -> AggregationResult {
let mut sub_aggs = HashMap::new();
for (name, collector) in self.sub_collectors {
sub_aggs.insert(name, collector.finish());
}
AggregationResult::Bucket(BucketResult {
buckets: vec![Bucket {
key: BucketKey::String("filter".into()),
doc_count: self.count,
sub_aggs,
}],
})
}
}