use crate::core::DocId;
use super::{AggregationResult, Aggregator, AggregatorFactory, MetricResult};
use crate::segment::reader::SegmentReader;
/// Factory for single-field numeric metric aggregations (avg, sum, min, max,
/// value_count, stats, extended_stats).
pub struct MetricAggFactory {
    /// Name of the field to aggregate; matched against `field_name` in each
    /// segment header to find the field id.
    pub field_name: String,
    /// Which metric to compute. It selects the value reported by `merge_results`
    /// and gates some per-segment collector shortcuts.
    pub metric_type: MetricType,
}
/// The metric reported after merging per-segment results.
#[derive(Clone, Copy)]
pub enum MetricType {
    /// Total sum divided by total value count.
    Avg,
    /// Sum of all values.
    Sum,
    /// Smallest value.
    Min,
    /// Largest value.
    Max,
    /// Number of values counted (the merged `count`).
    ValueCount,
    /// count, min, max, avg and sum, built with `MetricResult::stats`.
    Stats,
    /// `Stats` plus `sum_of_squares`, `variance`, `std_deviation` and
    /// `std_deviation_bounds.upper` / `.lower` (avg ± 2 · std_deviation).
    ExtendedStats,
}
impl AggregatorFactory for MetricAggFactory {
    /// Builds the per-segment collector for this metric.
    ///
    /// Column-level shortcuts are used only where they give the same answer as
    /// reading every collected document:
    /// * a column treated as constant, with no nulls, goes to `ConstantMetricCollector`;
    /// * `ValueCount` on a column whose stats report zero nulls goes to
    ///   `ValueCountFastCollector`. With no nulls, every collected doc
    ///   contributes exactly one value.
    ///
    /// Everything else goes through the exact `MetricCollector`, which reads each
    /// collected document's value. This includes Min/Max: column stats describe the
    /// whole segment, not the subset of documents a query actually matched, so using
    /// `stats.min` / `stats.max` would be wrong for any filtered aggregation.
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let field_id = reader
            .header()
            .fields
            .iter()
            .find(|f| f.field_name == self.field_name)
            .map(|f| f.field_id);
        if let Some(fid) = field_id {
            if let Some(col) = reader.column(fid) {
                if col.is_constant() {
                    // Missing stats are treated as "no nulls".
                    let null_count = col.stats().map_or(0, |s| s.null_count);
                    // Null docs cannot be told apart from valued docs without reading
                    // the column, so only shortcut when there are none. If
                    // `constant_value` is unexpectedly absent, fall back to the exact
                    // path instead of panicking.
                    if null_count == 0 {
                        if let Some(value) = col.constant_value() {
                            return Box::new(ConstantMetricCollector {
                                value,
                                non_null_docs: col.doc_count(),
                                collected: 0,
                            });
                        }
                    }
                }
                if matches!(self.metric_type, MetricType::ValueCount) {
                    if let Some(stats) = col.stats() {
                        // With nulls present, a partial match would count null docs
                        // as values, so only shortcut for null-free columns.
                        if stats.null_count == 0 {
                            return Box::new(ValueCountFastCollector {
                                doc_count: col.doc_count(),
                                // No nulls: every doc carries a value.
                                non_null_count: col.doc_count(),
                                collected: 0,
                            });
                        }
                    }
                }
            }
        }
        // Exact path: reads values per collected doc. The column is `None` when the
        // field is unknown, in which case nothing is accumulated.
        let col = super::bucket::OwnedColumn::new(field_id, reader);
        Box::new(MetricCollector {
            col,
            sum: 0.0,
            sum_of_squares: 0.0,
            count: 0,
            min: f64::INFINITY,
            max: f64::NEG_INFINITY,
        })
    }

    /// Combines per-segment results into the final metric for `self.metric_type`.
    ///
    /// Each segment's `Metric` result is read through its `extra` keys `count`,
    /// `sum`, `sum_of_squares`, `min` and `max`. Missing keys count as the neutral
    /// element (0 for sums and counts, ±∞ for min/max); non-`Metric` results are
    /// ignored. Returns `MetricResult::single(None)` when the merged count is zero.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let mut total_sum = 0.0f64;
        let mut total_sum_of_squares = 0.0f64;
        let mut total_count = 0u64;
        let mut global_min = f64::INFINITY;
        let mut global_max = f64::NEG_INFINITY;
        for r in &results {
            if let AggregationResult::Metric(m) = r {
                let count = m.extra.get("count").copied().unwrap_or(0.0) as u64;
                let sum = m.extra.get("sum").copied().unwrap_or(0.0);
                let sum_sq = m.extra.get("sum_of_squares").copied().unwrap_or(0.0);
                let min = m.extra.get("min").copied().unwrap_or(f64::INFINITY);
                let max = m.extra.get("max").copied().unwrap_or(f64::NEG_INFINITY);
                total_sum += sum;
                total_sum_of_squares += sum_sq;
                total_count += count;
                if min < global_min {
                    global_min = min;
                }
                if max > global_max {
                    global_max = max;
                }
            }
        }
        if total_count == 0 {
            return AggregationResult::Metric(MetricResult::single(None));
        }
        let avg = total_sum / total_count as f64;
        match self.metric_type {
            MetricType::Avg => AggregationResult::Metric(MetricResult::single(Some(avg))),
            MetricType::Sum => AggregationResult::Metric(MetricResult::single(Some(total_sum))),
            MetricType::Min => AggregationResult::Metric(MetricResult::single(Some(global_min))),
            MetricType::Max => AggregationResult::Metric(MetricResult::single(Some(global_max))),
            MetricType::ValueCount => {
                AggregationResult::Metric(MetricResult::single(Some(total_count as f64)))
            }
            MetricType::Stats => AggregationResult::Metric(MetricResult::stats(
                total_count,
                global_min,
                global_max,
                avg,
                total_sum,
            )),
            MetricType::ExtendedStats => {
                // Population variance via E[x²] − E[x]². Cancellation can push it
                // slightly below zero; clamp before reporting (NaN passes through).
                let raw_variance = (total_sum_of_squares / total_count as f64) - (avg * avg);
                let variance = if raw_variance < 0.0 { 0.0 } else { raw_variance };
                let std_dev = variance.max(0.0).sqrt();
                let mut result =
                    MetricResult::stats(total_count, global_min, global_max, avg, total_sum);
                result
                    .extra
                    .insert("sum_of_squares".into(), total_sum_of_squares);
                result.extra.insert("variance".into(), variance);
                result.extra.insert("std_deviation".into(), std_dev);
                result
                    .extra
                    .insert("std_deviation_bounds.upper".into(), avg + 2.0 * std_dev);
                result
                    .extra
                    .insert("std_deviation_bounds.lower".into(), avg - 2.0 * std_dev);
                AggregationResult::Metric(result)
            }
        }
    }
}
/// Exact per-segment accumulator: reads the numeric value of every collected
/// document from `col` and keeps a running sum, sum of squares, count, min and max.
///
/// Documents without a value, or a missing column (`col == None`), contribute nothing.
struct MetricCollector {
    col: Option<super::bucket::OwnedColumn>,
    sum: f64,
    sum_of_squares: f64,
    count: u64,
    min: f64,
    max: f64,
}
// SAFETY: NOTE(review) — `Send` is asserted by hand, presumably because `OwnedColumn`
// is not auto-`Send`; confirm it does not alias data mutated from other threads.
unsafe impl Send for MetricCollector {}
impl MetricCollector {
    /// Folds one value into the running statistics.
    fn absorb_value(&mut self, v: f64) {
        self.sum += v;
        self.sum_of_squares += v * v;
        self.count += 1;
        if v < self.min {
            self.min = v;
        }
        if v > self.max {
            self.max = v;
        }
    }
}
impl Aggregator for MetricCollector {
    fn collect(&mut self, doc_id: DocId) {
        let value = match &self.col {
            Some(column) => column.numeric_value(doc_id.as_u32()),
            None => None,
        };
        if let Some(v) = value {
            self.absorb_value(v);
        }
    }
    fn collect_range(&mut self, start: u32, end: u32) {
        if self.col.is_none() {
            return;
        }
        for doc in start..end {
            // The column borrow ends before `absorb_value` takes `&mut self`.
            let value = self.col.as_ref().and_then(|column| column.numeric_value(doc));
            if let Some(v) = value {
                self.absorb_value(v);
            }
        }
    }
    /// Emits full stats plus `sum_of_squares` (needed when merging extended stats),
    /// or an empty single result if no value was seen.
    fn finish(self: Box<Self>) -> AggregationResult {
        let MetricCollector {
            sum,
            sum_of_squares,
            count,
            min,
            max,
            ..
        } = *self;
        if count == 0 {
            return AggregationResult::Metric(MetricResult::single(None));
        }
        let mut result = MetricResult::stats(count, min, max, sum / count as f64, sum);
        result.extra.insert("sum_of_squares".into(), sum_of_squares);
        AggregationResult::Metric(result)
    }
}
/// Min/Max collector that never reads document values.
///
/// `min` and `max` are fixed when the collector is built; `collect` only counts
/// documents. NOTE(review): the reported bounds therefore do not depend on which
/// documents were collected. They are correct for the matched set only if the
/// values supplied at construction describe exactly those documents.
///
/// All fields are plain numbers, so the type is `Send` automatically and needs
/// no `unsafe impl`.
struct StatsMetricCollector {
    min: f64,
    max: f64,
    doc_count: u32,
    collected: u64,
}
impl Aggregator for StatsMetricCollector {
    fn collect(&mut self, _doc_id: DocId) {
        self.collected += 1;
    }
    /// Empty result for an empty segment or when nothing was collected; otherwise
    /// stats with the fixed bounds and zero avg/sum.
    fn finish(self: Box<Self>) -> AggregationResult {
        let nothing_seen = self.doc_count == 0 || self.collected == 0;
        if nothing_seen {
            return AggregationResult::Metric(MetricResult::single(None));
        }
        AggregationResult::Metric(MetricResult::stats(
            self.collected,
            self.min,
            self.max,
            0.0,
            0.0,
        ))
    }
}
/// Value-count collector that counts collected documents instead of reading values.
///
/// When at least `doc_count` documents were collected, the precomputed
/// `non_null_count` is reported. Otherwise the raw number of collected documents
/// is reported, which equals the value count only if the segment has no null docs.
///
/// All fields are plain numbers, so the type is `Send` automatically and needs
/// no `unsafe impl`.
struct ValueCountFastCollector {
    doc_count: u32,
    non_null_count: u32,
    collected: u64,
}
impl Aggregator for ValueCountFastCollector {
    fn collect(&mut self, _doc_id: DocId) {
        self.collected += 1;
    }
    /// Emits the count both as the single value and under the `count` extra key,
    /// which is what merging reads.
    fn finish(self: Box<Self>) -> AggregationResult {
        // Compare in u64: `collected as u32` would silently wrap past u32::MAX.
        let count = if self.collected >= u64::from(self.doc_count) {
            u64::from(self.non_null_count)
        } else {
            self.collected
        };
        let mut result = MetricResult::single(Some(count as f64));
        result.extra.insert("count".into(), count as f64);
        AggregationResult::Metric(result)
    }
}
/// Collector for a column treated as holding a single constant `value`.
///
/// No document values are read: results are derived from the number of collected
/// documents, with min = max = avg = `value` and sum = `value` × count.
///
/// All fields are plain numbers, so the type is `Send` automatically and needs
/// no `unsafe impl`.
struct ConstantMetricCollector {
    value: f64,
    non_null_docs: u32,
    collected: u64,
}
impl Aggregator for ConstantMetricCollector {
    fn collect(&mut self, _doc_id: DocId) {
        self.collected += 1;
    }
    fn finish(self: Box<Self>) -> AggregationResult {
        // Null docs cannot be distinguished here, so `collected` may include them.
        // Capping at `non_null_docs` can only remove over-counting, and makes the
        // whole-segment case exact.
        let count = self.collected.min(u64::from(self.non_null_docs));
        if count == 0 {
            return AggregationResult::Metric(MetricResult::single(None));
        }
        let n = count as f64;
        let sum = self.value * n;
        let mut result = MetricResult::stats(count, self.value, self.value, self.value, sum);
        // Without this key a merge computing variance from `sum_of_squares` would
        // treat this segment as contributing 0, corrupting extended stats.
        result
            .extra
            .insert("sum_of_squares".into(), self.value * self.value * n);
        AggregationResult::Metric(result)
    }
}
/// Factory for a geo-bounds aggregation: the lat/lon bounding box of the geo
/// points stored in `field_name` for the collected documents.
pub struct GeoBoundsAggFactory {
    /// Name of the geo-point field; matched against each segment header.
    pub field_name: String,
}
impl AggregatorFactory for GeoBoundsAggFactory {
    /// Resolves the field's point store for this segment. If the field is unknown,
    /// the collector gets `None` and collects nothing.
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let field_id = reader.header().fields.iter().find_map(|f| {
            if f.field_name == self.field_name {
                Some(f.field_id)
            } else {
                None
            }
        });
        let store = match field_id {
            Some(fid) => reader.geo_points(fid),
            None => None,
        };
        Box::new(GeoBoundsCollector {
            store,
            min_lat: f64::INFINITY,
            max_lat: f64::NEG_INFINITY,
            min_lon: f64::INFINITY,
            max_lon: f64::NEG_INFINITY,
            count: 0,
        })
    }

    /// Unions per-segment boxes. Segments without a positive `count` extra are
    /// skipped. Returns `MetricResult::single(None)` when no point was seen.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        // north/south are the latitude extremes, west/east the longitude extremes.
        let mut north = f64::NEG_INFINITY;
        let mut south = f64::INFINITY;
        let mut west = f64::INFINITY;
        let mut east = f64::NEG_INFINITY;
        let mut count = 0u64;
        let metrics = results.iter().filter_map(|r| {
            if let AggregationResult::Metric(m) = r {
                Some(m)
            } else {
                None
            }
        });
        for m in metrics {
            let lookup = |key: &str| m.extra.get(key).copied();
            let seen = match lookup("count") {
                Some(c) if c > 0.0 => c,
                _ => continue,
            };
            count += seen as u64;
            if let Some(v) = lookup("top_left.lat") {
                if v > north {
                    north = v;
                }
            }
            if let Some(v) = lookup("bottom_right.lat") {
                if v < south {
                    south = v;
                }
            }
            if let Some(v) = lookup("top_left.lon") {
                if v < west {
                    west = v;
                }
            }
            if let Some(v) = lookup("bottom_right.lon") {
                if v > east {
                    east = v;
                }
            }
        }
        if count == 0 {
            return AggregationResult::Metric(MetricResult::single(None));
        }
        let mut merged = MetricResult::single(None);
        for (key, value) in [
            ("count", count as f64),
            ("top_left.lat", north),
            ("top_left.lon", west),
            ("bottom_right.lat", south),
            ("bottom_right.lon", east),
        ] {
            merged.extra.insert(key.into(), value);
        }
        AggregationResult::Metric(merged)
    }
}
/// Per-segment collector tracking the lat/lon extremes of collected geo points.
/// Documents with no point in `store` (or a missing store) are ignored.
struct GeoBoundsCollector {
    store: Option<crate::spatial::geo::GeoPointStore>,
    min_lat: f64,
    max_lat: f64,
    min_lon: f64,
    max_lon: f64,
    count: u64,
}
// SAFETY: NOTE(review) — `Send` is asserted by hand, presumably because `GeoPointStore`
// is not auto-`Send`; confirm it does not alias data mutated from other threads.
unsafe impl Send for GeoBoundsCollector {}
impl Aggregator for GeoBoundsCollector {
    fn collect(&mut self, doc_id: DocId) {
        let Some(point) = self.store.as_ref().and_then(|s| s.get(doc_id.as_u32())) else {
            return;
        };
        if point.lat < self.min_lat {
            self.min_lat = point.lat;
        }
        if point.lat > self.max_lat {
            self.max_lat = point.lat;
        }
        if point.lon < self.min_lon {
            self.min_lon = point.lon;
        }
        if point.lon > self.max_lon {
            self.max_lon = point.lon;
        }
        self.count += 1;
    }
    /// Emits `count` and the box corners (top-left = max lat / min lon,
    /// bottom-right = min lat / max lon), or an empty result if nothing was seen.
    fn finish(self: Box<Self>) -> AggregationResult {
        if self.count == 0 {
            return AggregationResult::Metric(MetricResult::single(None));
        }
        let mut bounds = MetricResult::single(None);
        for (key, value) in [
            ("count", self.count as f64),
            ("top_left.lat", self.max_lat),
            ("top_left.lon", self.min_lon),
            ("bottom_right.lat", self.min_lat),
            ("bottom_right.lon", self.max_lon),
        ] {
            bounds.extra.insert(key.into(), value);
        }
        AggregationResult::Metric(bounds)
    }
}
/// Factory for a geo-centroid aggregation: the arithmetic mean latitude and
/// longitude of the geo points stored in `field_name` for the collected documents.
pub struct GeoCentroidAggFactory {
    /// Name of the geo-point field; matched against each segment header.
    pub field_name: String,
}
impl AggregatorFactory for GeoCentroidAggFactory {
    /// Resolves the field's point store for this segment. If the field is unknown,
    /// the collector gets `None` and collects nothing.
    fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
        let field_id = reader.header().fields.iter().find_map(|f| {
            if f.field_name == self.field_name {
                Some(f.field_id)
            } else {
                None
            }
        });
        let store = match field_id {
            Some(fid) => reader.geo_points(fid),
            None => None,
        };
        Box::new(GeoCentroidCollector {
            store,
            sum_lat: 0.0,
            sum_lon: 0.0,
            count: 0,
        })
    }

    /// Sums per-segment `count`, `sum_lat` and `sum_lon` (missing keys count as 0)
    /// and divides by the total count. Returns `MetricResult::single(None)` when no
    /// point was seen.
    fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
        let (total_count, total_sum_lat, total_sum_lon) = results
            .iter()
            .filter_map(|r| {
                if let AggregationResult::Metric(m) = r {
                    Some(m)
                } else {
                    None
                }
            })
            .fold((0u64, 0.0f64, 0.0f64), |(count, lat, lon), m| {
                let read = |key: &str| m.extra.get(key).copied().unwrap_or(0.0);
                (
                    count + read("count") as u64,
                    lat + read("sum_lat"),
                    lon + read("sum_lon"),
                )
            });
        if total_count == 0 {
            return AggregationResult::Metric(MetricResult::single(None));
        }
        let n = total_count as f64;
        let mut merged = MetricResult::single(None);
        for (key, value) in [
            ("count", n),
            ("lat", total_sum_lat / n),
            ("lon", total_sum_lon / n),
            ("sum_lat", total_sum_lat),
            ("sum_lon", total_sum_lon),
        ] {
            merged.extra.insert(key.into(), value);
        }
        AggregationResult::Metric(merged)
    }
}
/// Per-segment collector summing the lat/lon of collected geo points.
/// Documents with no point in `store` (or a missing store) are ignored.
struct GeoCentroidCollector {
    store: Option<crate::spatial::geo::GeoPointStore>,
    sum_lat: f64,
    sum_lon: f64,
    count: u64,
}
// SAFETY: NOTE(review) — `Send` is asserted by hand, presumably because `GeoPointStore`
// is not auto-`Send`; confirm it does not alias data mutated from other threads.
unsafe impl Send for GeoCentroidCollector {}
impl Aggregator for GeoCentroidCollector {
    fn collect(&mut self, doc_id: DocId) {
        let point = match &self.store {
            Some(store) => store.get(doc_id.as_u32()),
            None => None,
        };
        if let Some(p) = point {
            self.sum_lat += p.lat;
            self.sum_lon += p.lon;
            self.count += 1;
        }
    }
    /// Emits the mean `lat` / `lon` together with the raw sums and `count`
    /// (the sums are what merging reads), or an empty result if nothing was seen.
    fn finish(self: Box<Self>) -> AggregationResult {
        if self.count == 0 {
            return AggregationResult::Metric(MetricResult::single(None));
        }
        let n = self.count as f64;
        let mut centroid = MetricResult::single(None);
        for (key, value) in [
            ("count", n),
            ("lat", self.sum_lat / n),
            ("lon", self.sum_lon / n),
            ("sum_lat", self.sum_lat),
            ("sum_lon", self.sum_lon),
        ] {
            centroid.extra.insert(key.into(), value);
        }
        AggregationResult::Metric(centroid)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a metric factory for `metric_type`; the field name does not affect merging.
    fn metric_factory(metric_type: MetricType) -> MetricAggFactory {
        MetricAggFactory {
            field_name: "x".into(),
            metric_type,
        }
    }

    /// Unwraps a `Metric` result, failing the test on any other variant.
    fn into_metric(result: AggregationResult) -> MetricResult {
        let AggregationResult::Metric(m) = result else {
            panic!("expected AggregationResult::Metric");
        };
        m
    }

    /// Builds a `Metric` result whose `extra` map holds the given entries.
    fn partial(entries: &[(&str, f64)]) -> AggregationResult {
        let mut m = MetricResult::single(None);
        for &(key, value) in entries {
            m.extra.insert(key.into(), value);
        }
        AggregationResult::Metric(m)
    }

    #[test]
    fn merge_avg() {
        let merged = into_metric(metric_factory(MetricType::Avg).merge_results(vec![
            AggregationResult::Metric(MetricResult::stats(3, 1.0, 3.0, 2.0, 6.0)),
            AggregationResult::Metric(MetricResult::stats(2, 4.0, 5.0, 4.5, 9.0)),
        ]));
        // (6 + 9) / (3 + 2)
        assert_eq!(merged.value, Some(3.0));
    }

    #[test]
    fn merge_sum() {
        let merged = into_metric(metric_factory(MetricType::Sum).merge_results(vec![
            AggregationResult::Metric(MetricResult::stats(2, 0.0, 0.0, 0.0, 10.0)),
            AggregationResult::Metric(MetricResult::stats(3, 0.0, 0.0, 0.0, 20.0)),
        ]));
        assert_eq!(merged.value, Some(30.0));
    }

    #[test]
    fn merge_min_max() {
        let merged = into_metric(metric_factory(MetricType::Min).merge_results(vec![
            AggregationResult::Metric(MetricResult::stats(1, 5.0, 5.0, 5.0, 5.0)),
            AggregationResult::Metric(MetricResult::stats(1, 2.0, 2.0, 2.0, 2.0)),
        ]));
        assert_eq!(merged.value, Some(2.0));
    }

    #[test]
    fn merge_max() {
        let merged = into_metric(metric_factory(MetricType::Max).merge_results(vec![
            AggregationResult::Metric(MetricResult::stats(1, 5.0, 5.0, 5.0, 5.0)),
            AggregationResult::Metric(MetricResult::stats(1, 2.0, 2.0, 2.0, 2.0)),
        ]));
        assert_eq!(merged.value, Some(5.0));
    }

    #[test]
    fn merge_value_count() {
        let merged = into_metric(metric_factory(MetricType::ValueCount).merge_results(vec![
            AggregationResult::Metric(MetricResult::stats(3, 1.0, 2.0, 1.5, 4.5)),
            AggregationResult::Metric(MetricResult::stats(2, 1.0, 2.0, 1.5, 3.0)),
        ]));
        assert_eq!(merged.value, Some(5.0));
    }

    #[test]
    fn merge_empty() {
        let merged = into_metric(
            metric_factory(MetricType::Avg)
                .merge_results(vec![AggregationResult::Metric(MetricResult::single(None))]),
        );
        assert!(merged.value.is_none());
    }

    #[test]
    fn geo_bounds_merge_takes_outer_box() {
        let factory = GeoBoundsAggFactory {
            field_name: "loc".into(),
        };
        let merged = into_metric(factory.merge_results(vec![
            partial(&[
                ("count", 2.0),
                ("top_left.lat", 10.0),
                ("top_left.lon", -5.0),
                ("bottom_right.lat", 1.0),
                ("bottom_right.lon", 3.0),
            ]),
            partial(&[
                ("count", 1.0),
                ("top_left.lat", 20.0),
                ("top_left.lon", 0.0),
                ("bottom_right.lat", 5.0),
                ("bottom_right.lon", 8.0),
            ]),
        ]));
        assert_eq!(merged.extra.get("count").copied(), Some(3.0));
        assert_eq!(merged.extra.get("top_left.lat").copied(), Some(20.0));
        assert_eq!(merged.extra.get("top_left.lon").copied(), Some(-5.0));
        assert_eq!(merged.extra.get("bottom_right.lat").copied(), Some(1.0));
        assert_eq!(merged.extra.get("bottom_right.lon").copied(), Some(8.0));
    }

    #[test]
    fn geo_centroid_merge_weights_by_count() {
        let factory = GeoCentroidAggFactory {
            field_name: "loc".into(),
        };
        let merged = into_metric(factory.merge_results(vec![
            partial(&[("count", 2.0), ("sum_lat", 4.0), ("sum_lon", 6.0)]),
            partial(&[("count", 2.0), ("sum_lat", 8.0), ("sum_lon", 2.0)]),
        ]));
        assert_eq!(merged.extra.get("count").copied(), Some(4.0));
        assert_eq!(merged.extra.get("lat").copied(), Some(3.0));
        assert_eq!(merged.extra.get("lon").copied(), Some(2.0));
    }
}