1use crate::core::DocId;
6
7use super::{AggregationResult, Aggregator, AggregatorFactory, MetricResult};
8use crate::segment::reader::SegmentReader;
9
/// Factory for a single-value or stats metric aggregation over one numeric field.
#[derive(Debug, Clone)]
pub struct MetricAggFactory {
    /// Name of the field whose values are aggregated.
    pub field_name: String,
    /// Which metric to compute from the collected values.
    pub metric_type: MetricType,
}

/// The metric a [`MetricAggFactory`] computes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetricType {
    /// Arithmetic mean of the values.
    Avg,
    /// Sum of the values.
    Sum,
    /// Smallest value.
    Min,
    /// Largest value.
    Max,
    /// Number of values.
    ValueCount,
    /// count / min / max / avg / sum together.
    Stats,
    /// `Stats` plus sum of squares, variance, standard deviation and ±2σ bounds.
    ExtendedStats,
}
26
27impl AggregatorFactory for MetricAggFactory {
28 fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
29 let field_id = reader
30 .header()
31 .fields
32 .iter()
33 .find(|f| f.field_name == self.field_name)
34 .map(|f| f.field_id);
35
36 if let Some(fid) = field_id {
37 if let Some(col) = reader.column(fid) {
38 if col.is_constant() {
40 let value = col.constant_value().unwrap();
41 let doc_count = col.doc_count();
42 let null_count = col.stats().map_or(0, |s| s.null_count);
43 return Box::new(ConstantMetricCollector {
44 value,
45 non_null_docs: doc_count - null_count,
46 collected: 0,
47 });
48 }
49
50 if matches!(self.metric_type, MetricType::Min | MetricType::Max) {
52 if let Some(stats) = col.stats() {
53 return Box::new(StatsMetricCollector {
54 min: stats.min,
55 max: stats.max,
56 doc_count: col.doc_count(),
57 collected: 0,
58 });
59 }
60 }
61
62 if matches!(self.metric_type, MetricType::ValueCount) {
64 if let Some(stats) = col.stats() {
65 return Box::new(ValueCountFastCollector {
66 doc_count: col.doc_count(),
67 non_null_count: col.doc_count() - stats.null_count,
68 collected: 0,
69 });
70 }
71 }
72 }
73 }
74
75 let col = super::bucket::OwnedColumn::new(field_id, reader);
76
77 Box::new(MetricCollector {
78 col,
79 sum: 0.0,
80 sum_of_squares: 0.0,
81 count: 0,
82 min: f64::INFINITY,
83 max: f64::NEG_INFINITY,
84 })
85 }
86
87 fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
88 let mut total_sum = 0.0f64;
89 let mut total_sum_of_squares = 0.0f64;
90 let mut total_count = 0u64;
91 let mut global_min = f64::INFINITY;
92 let mut global_max = f64::NEG_INFINITY;
93
94 for r in &results {
95 if let AggregationResult::Metric(m) = r {
96 let count = m.extra.get("count").copied().unwrap_or(0.0) as u64;
97 let sum = m.extra.get("sum").copied().unwrap_or(0.0);
98 let sum_sq = m.extra.get("sum_of_squares").copied().unwrap_or(0.0);
99 let min = m.extra.get("min").copied().unwrap_or(f64::INFINITY);
100 let max = m.extra.get("max").copied().unwrap_or(f64::NEG_INFINITY);
101 total_sum += sum;
102 total_sum_of_squares += sum_sq;
103 total_count += count;
104 if min < global_min {
105 global_min = min;
106 }
107 if max > global_max {
108 global_max = max;
109 }
110 }
111 }
112
113 if total_count == 0 {
114 return AggregationResult::Metric(MetricResult::single(None));
115 }
116
117 let avg = total_sum / total_count as f64;
118
119 match self.metric_type {
120 MetricType::Avg => AggregationResult::Metric(MetricResult::single(Some(avg))),
121 MetricType::Sum => AggregationResult::Metric(MetricResult::single(Some(total_sum))),
122 MetricType::Min => AggregationResult::Metric(MetricResult::single(Some(global_min))),
123 MetricType::Max => AggregationResult::Metric(MetricResult::single(Some(global_max))),
124 MetricType::ValueCount => {
125 AggregationResult::Metric(MetricResult::single(Some(total_count as f64)))
126 }
127 MetricType::Stats => AggregationResult::Metric(MetricResult::stats(
128 total_count,
129 global_min,
130 global_max,
131 avg,
132 total_sum,
133 )),
134 MetricType::ExtendedStats => {
135 let variance = (total_sum_of_squares / total_count as f64) - (avg * avg);
136 let std_dev = variance.max(0.0).sqrt();
137 let mut result =
138 MetricResult::stats(total_count, global_min, global_max, avg, total_sum);
139 result
140 .extra
141 .insert("sum_of_squares".into(), total_sum_of_squares);
142 result.extra.insert("variance".into(), variance);
143 result.extra.insert("std_deviation".into(), std_dev);
144 result
145 .extra
146 .insert("std_deviation_bounds.upper".into(), avg + 2.0 * std_dev);
147 result
148 .extra
149 .insert("std_deviation_bounds.lower".into(), avg - 2.0 * std_dev);
150 AggregationResult::Metric(result)
151 }
152 }
153 }
154}
155
156struct MetricCollector {
157 col: Option<super::bucket::OwnedColumn>,
158 sum: f64,
159 sum_of_squares: f64,
160 count: u64,
161 min: f64,
162 max: f64,
163}
164
165unsafe impl Send for MetricCollector {}
166
167impl Aggregator for MetricCollector {
168 fn collect(&mut self, doc_id: DocId) {
169 let Some(v) = self
170 .col
171 .as_ref()
172 .and_then(|c| c.numeric_value(doc_id.as_u32()))
173 else {
174 return;
175 };
176
177 self.sum += v;
178 self.sum_of_squares += v * v;
179 self.count += 1;
180 if v < self.min {
181 self.min = v;
182 }
183 if v > self.max {
184 self.max = v;
185 }
186 }
187
188 fn collect_range(&mut self, start: u32, end: u32) {
189 let Some(col) = &self.col else { return };
190 for i in start..end {
191 if let Some(v) = col.numeric_value(i) {
192 self.sum += v;
193 self.sum_of_squares += v * v;
194 self.count += 1;
195 if v < self.min {
196 self.min = v;
197 }
198 if v > self.max {
199 self.max = v;
200 }
201 }
202 }
203 }
204
205 fn finish(self: Box<Self>) -> AggregationResult {
206 if self.count == 0 {
207 return AggregationResult::Metric(MetricResult::single(None));
208 }
209 let avg = self.sum / self.count as f64;
210 let mut result = MetricResult::stats(self.count, self.min, self.max, avg, self.sum);
212 result
213 .extra
214 .insert("sum_of_squares".into(), self.sum_of_squares);
215 AggregationResult::Metric(result)
216 }
217}
218
219struct StatsMetricCollector {
223 min: f64,
224 max: f64,
225 doc_count: u32,
226 collected: u64,
227}
228
229unsafe impl Send for StatsMetricCollector {}
230
231impl Aggregator for StatsMetricCollector {
232 fn collect(&mut self, _doc_id: DocId) {
233 self.collected += 1;
234 }
235
236 fn finish(self: Box<Self>) -> AggregationResult {
237 let count = if self.doc_count == 0 {
238 0
239 } else {
240 self.collected
241 };
242 if count == 0 {
243 return AggregationResult::Metric(MetricResult::single(None));
244 }
245 AggregationResult::Metric(MetricResult::stats(count, self.min, self.max, 0.0, 0.0))
248 }
249}
250
251struct ValueCountFastCollector {
256 doc_count: u32,
257 non_null_count: u32,
258 collected: u64,
259}
260
261unsafe impl Send for ValueCountFastCollector {}
262
263impl Aggregator for ValueCountFastCollector {
264 fn collect(&mut self, _doc_id: DocId) {
265 self.collected += 1;
266 }
267
268 fn finish(self: Box<Self>) -> AggregationResult {
269 let count = if self.collected as u32 >= self.doc_count {
270 self.non_null_count as u64
271 } else {
272 self.collected
273 };
274 let mut result = MetricResult::single(Some(count as f64));
276 result.extra.insert("count".into(), count as f64);
277 AggregationResult::Metric(result)
278 }
279}
280
281struct ConstantMetricCollector {
282 value: f64,
283 non_null_docs: u32,
285 collected: u64,
287}
288
289unsafe impl Send for ConstantMetricCollector {}
290
291impl Aggregator for ConstantMetricCollector {
292 fn collect(&mut self, _doc_id: DocId) {
293 self.collected += 1;
296 }
297
298 fn finish(self: Box<Self>) -> AggregationResult {
299 let count = if self.non_null_docs == 0 {
305 0
306 } else {
307 self.collected
308 };
309
310 if count == 0 {
311 return AggregationResult::Metric(MetricResult::single(None));
312 }
313
314 let sum = self.value * count as f64;
315 AggregationResult::Metric(MetricResult::stats(
316 count, self.value, self.value, self.value, sum,
317 ))
318 }
319}
320
321pub struct GeoBoundsAggFactory {
324 pub field_name: String,
325}
326
327impl AggregatorFactory for GeoBoundsAggFactory {
328 fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
329 let field_id = reader
330 .header()
331 .fields
332 .iter()
333 .find(|f| f.field_name == self.field_name)
334 .map(|f| f.field_id);
335 let store = field_id.and_then(|fid| reader.geo_points(fid));
336 Box::new(GeoBoundsCollector {
337 store,
338 min_lat: f64::INFINITY,
339 max_lat: f64::NEG_INFINITY,
340 min_lon: f64::INFINITY,
341 max_lon: f64::NEG_INFINITY,
342 count: 0,
343 })
344 }
345
346 fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
347 let mut min_lat = f64::INFINITY;
348 let mut max_lat = f64::NEG_INFINITY;
349 let mut min_lon = f64::INFINITY;
350 let mut max_lon = f64::NEG_INFINITY;
351 let mut count = 0u64;
352
353 for r in &results {
354 if let AggregationResult::Metric(m) = r {
355 if let Some(&c) = m.extra.get("count") {
356 if c > 0.0 {
357 count += c as u64;
358 if let Some(&v) = m.extra.get("top_left.lat") {
359 if v > max_lat {
360 max_lat = v;
361 }
362 }
363 if let Some(&v) = m.extra.get("bottom_right.lat") {
364 if v < min_lat {
365 min_lat = v;
366 }
367 }
368 if let Some(&v) = m.extra.get("top_left.lon") {
369 if v < min_lon {
370 min_lon = v;
371 }
372 }
373 if let Some(&v) = m.extra.get("bottom_right.lon") {
374 if v > max_lon {
375 max_lon = v;
376 }
377 }
378 }
379 }
380 }
381 }
382
383 if count == 0 {
384 return AggregationResult::Metric(MetricResult::single(None));
385 }
386
387 let mut result = MetricResult::single(None);
388 result.extra.insert("count".into(), count as f64);
389 result.extra.insert("top_left.lat".into(), max_lat);
390 result.extra.insert("top_left.lon".into(), min_lon);
391 result.extra.insert("bottom_right.lat".into(), min_lat);
392 result.extra.insert("bottom_right.lon".into(), max_lon);
393 AggregationResult::Metric(result)
394 }
395}
396
397struct GeoBoundsCollector {
398 store: Option<crate::spatial::geo::GeoPointStore>,
399 min_lat: f64,
400 max_lat: f64,
401 min_lon: f64,
402 max_lon: f64,
403 count: u64,
404}
405
406unsafe impl Send for GeoBoundsCollector {}
407
408impl Aggregator for GeoBoundsCollector {
409 fn collect(&mut self, doc_id: DocId) {
410 if let Some(store) = &self.store {
411 if let Some(point) = store.get(doc_id.as_u32()) {
412 if point.lat < self.min_lat {
413 self.min_lat = point.lat;
414 }
415 if point.lat > self.max_lat {
416 self.max_lat = point.lat;
417 }
418 if point.lon < self.min_lon {
419 self.min_lon = point.lon;
420 }
421 if point.lon > self.max_lon {
422 self.max_lon = point.lon;
423 }
424 self.count += 1;
425 }
426 }
427 }
428
429 fn finish(self: Box<Self>) -> AggregationResult {
430 if self.count == 0 {
431 return AggregationResult::Metric(MetricResult::single(None));
432 }
433 let mut result = MetricResult::single(None);
434 result.extra.insert("count".into(), self.count as f64);
435 result.extra.insert("top_left.lat".into(), self.max_lat);
436 result.extra.insert("top_left.lon".into(), self.min_lon);
437 result.extra.insert("bottom_right.lat".into(), self.min_lat);
438 result.extra.insert("bottom_right.lon".into(), self.max_lon);
439 AggregationResult::Metric(result)
440 }
441}
442
443pub struct GeoCentroidAggFactory {
446 pub field_name: String,
447}
448
449impl AggregatorFactory for GeoCentroidAggFactory {
450 fn create_collector(&self, reader: &SegmentReader) -> Box<dyn Aggregator> {
451 let field_id = reader
452 .header()
453 .fields
454 .iter()
455 .find(|f| f.field_name == self.field_name)
456 .map(|f| f.field_id);
457 let store = field_id.and_then(|fid| reader.geo_points(fid));
458 Box::new(GeoCentroidCollector {
459 store,
460 sum_lat: 0.0,
461 sum_lon: 0.0,
462 count: 0,
463 })
464 }
465
466 fn merge_results(&self, results: Vec<AggregationResult>) -> AggregationResult {
467 let mut total_sum_lat = 0.0f64;
468 let mut total_sum_lon = 0.0f64;
469 let mut total_count = 0u64;
470
471 for r in &results {
472 if let AggregationResult::Metric(m) = r {
473 let count = m.extra.get("count").copied().unwrap_or(0.0) as u64;
474 let sum_lat = m.extra.get("sum_lat").copied().unwrap_or(0.0);
475 let sum_lon = m.extra.get("sum_lon").copied().unwrap_or(0.0);
476 total_count += count;
477 total_sum_lat += sum_lat;
478 total_sum_lon += sum_lon;
479 }
480 }
481
482 if total_count == 0 {
483 return AggregationResult::Metric(MetricResult::single(None));
484 }
485
486 let mut result = MetricResult::single(None);
487 result.extra.insert("count".into(), total_count as f64);
488 result
489 .extra
490 .insert("lat".into(), total_sum_lat / total_count as f64);
491 result
492 .extra
493 .insert("lon".into(), total_sum_lon / total_count as f64);
494 result.extra.insert("sum_lat".into(), total_sum_lat);
495 result.extra.insert("sum_lon".into(), total_sum_lon);
496 AggregationResult::Metric(result)
497 }
498}
499
500struct GeoCentroidCollector {
501 store: Option<crate::spatial::geo::GeoPointStore>,
502 sum_lat: f64,
503 sum_lon: f64,
504 count: u64,
505}
506
507unsafe impl Send for GeoCentroidCollector {}
508
509impl Aggregator for GeoCentroidCollector {
510 fn collect(&mut self, doc_id: DocId) {
511 if let Some(store) = &self.store {
512 if let Some(point) = store.get(doc_id.as_u32()) {
513 self.sum_lat += point.lat;
514 self.sum_lon += point.lon;
515 self.count += 1;
516 }
517 }
518 }
519
520 fn finish(self: Box<Self>) -> AggregationResult {
521 if self.count == 0 {
522 return AggregationResult::Metric(MetricResult::single(None));
523 }
524 let mut result = MetricResult::single(None);
525 result.extra.insert("count".into(), self.count as f64);
526 result
527 .extra
528 .insert("lat".into(), self.sum_lat / self.count as f64);
529 result
530 .extra
531 .insert("lon".into(), self.sum_lon / self.count as f64);
532 result.extra.insert("sum_lat".into(), self.sum_lat);
533 result.extra.insert("sum_lon".into(), self.sum_lon);
534 AggregationResult::Metric(result)
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `merge_results` for `metric_type` and returns the merged headline value.
    fn merge(metric_type: MetricType, results: Vec<AggregationResult>) -> Option<f64> {
        let factory = MetricAggFactory {
            field_name: "x".into(),
            metric_type,
        };
        let AggregationResult::Metric(merged) = factory.merge_results(results) else {
            panic!("merge_results should return a metric result");
        };
        merged.value
    }

    /// Wraps a `MetricResult::stats` payload as one segment's result.
    fn segment(count: u64, min: f64, max: f64, avg: f64, sum: f64) -> AggregationResult {
        AggregationResult::Metric(MetricResult::stats(count, min, max, avg, sum))
    }

    #[test]
    fn merge_avg() {
        let merged = merge(
            MetricType::Avg,
            vec![segment(3, 1.0, 3.0, 2.0, 6.0), segment(2, 4.0, 5.0, 4.5, 9.0)],
        );
        // (6 + 9) / (3 + 2)
        assert_eq!(merged, Some(3.0));
    }

    #[test]
    fn merge_sum() {
        let merged = merge(
            MetricType::Sum,
            vec![segment(2, 0.0, 0.0, 0.0, 10.0), segment(3, 0.0, 0.0, 0.0, 20.0)],
        );
        assert_eq!(merged, Some(30.0));
    }

    #[test]
    fn merge_min_max() {
        let merged = merge(
            MetricType::Min,
            vec![segment(1, 5.0, 5.0, 5.0, 5.0), segment(1, 2.0, 2.0, 2.0, 2.0)],
        );
        assert_eq!(merged, Some(2.0));
    }

    #[test]
    fn merge_empty() {
        let merged = merge(
            MetricType::Avg,
            vec![AggregationResult::Metric(MetricResult::single(None))],
        );
        assert!(merged.is_none());
    }
}