// summavy 0.25.3
//
// Search engine library — documentation
//! Contains the aggregation request tree. Used to build an
//! [`AggregationCollector`](super::AggregationCollector).
//!
//! [`Aggregations`] is the top level entry point to create a request, which is a `HashMap<String,
//! Aggregation>`.
//!
//! Requests are compatible with the json format of elasticsearch.
//!
//! # Example
//!
//! ```
//! use tantivy::aggregation::bucket::RangeAggregation;
//! use tantivy::aggregation::agg_req::BucketAggregationType;
//! use tantivy::aggregation::agg_req::{Aggregation, Aggregations};
//! use tantivy::aggregation::agg_req::BucketAggregation;
//! let agg_req1: Aggregations = vec![
//!     (
//!         "range".to_string(),
//!         Aggregation::Bucket(BucketAggregation {
//!             bucket_agg: BucketAggregationType::Range(RangeAggregation{
//!                 field: "score".to_string(),
//!                 ranges: vec![(3f64..7f64).into(), (7f64..20f64).into()],
//!                 keyed: false,
//!             }),
//!             sub_aggregation: Default::default(),
//!         }),
//!     ),
//! ]
//! .into_iter()
//! .collect();
//!
//! let elasticsearch_compatible_json_req = r#"
//! {
//!   "range": {
//!     "range": {
//!       "field": "score",
//!       "ranges": [
//!         { "from": 3.0, "to": 7.0 },
//!         { "from": 7.0, "to": 20.0 }
//!       ]
//!     }
//!   }
//! }"#;
//! let agg_req2: Aggregations = serde_json::from_str(elasticsearch_compatible_json_req).unwrap();
//! assert_eq!(agg_req1, agg_req2);
//! ```

use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};

pub use super::bucket::RangeAggregation;
use super::bucket::{HistogramAggregation, TermsAggregation};
use super::metric::{AverageAggregation, StatsAggregation};
use super::VecWithNames;

/// The top-level aggregation request structure, which contains [`Aggregation`] and their user
/// defined names. It is also used in [buckets](BucketAggregation) to define sub-aggregations.
///
/// The key is the user defined name of the aggregation.
///
/// Serialized requests are compatible with the elasticsearch json format (see the module docs).
pub type Aggregations = HashMap<String, Aggregation>;

/// Like Aggregations, but optimized to work with the aggregation result
/// Like [`Aggregations`], but optimized to work with the aggregation result: metric and bucket
/// aggregations are split into two flat, name-indexed vectors instead of a single `HashMap`.
#[derive(Clone, Debug)]
pub(crate) struct AggregationsInternal {
    // Metric aggregations of this level, keyed by their user-defined names.
    pub(crate) metrics: VecWithNames<MetricAggregation>,
    // Bucket aggregations of this level, keyed by their user-defined names.
    // Their sub-aggregations are converted recursively.
    pub(crate) buckets: VecWithNames<BucketAggregationInternal>,
}

impl From<Aggregations> for AggregationsInternal {
    /// Converts a user request into the internal representation by partitioning
    /// the entries into metric and bucket aggregations; sub-aggregations are
    /// converted recursively via `.into()`.
    fn from(aggs: Aggregations) -> Self {
        let mut metric_entries = Vec::new();
        let mut bucket_entries = Vec::new();
        for (name, agg) in aggs {
            match agg {
                Aggregation::Metric(metric) => metric_entries.push((name, metric)),
                Aggregation::Bucket(bucket) => {
                    let internal = BucketAggregationInternal {
                        bucket_agg: bucket.bucket_agg,
                        sub_aggregation: bucket.sub_aggregation.into(),
                    };
                    bucket_entries.push((name, internal));
                }
            }
        }
        Self {
            metrics: VecWithNames::from_entries(metric_entries),
            buckets: VecWithNames::from_entries(bucket_entries),
        }
    }
}

#[derive(Clone, Debug)]
// Like BucketAggregation, but optimized to work with the result: the
// sub-aggregations are stored in the internal (flattened) representation.
pub(crate) struct BucketAggregationInternal {
    /// Bucket aggregation strategy to group documents.
    pub bucket_agg: BucketAggregationType,
    /// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the
    /// bucket.
    pub sub_aggregation: AggregationsInternal,
}

impl BucketAggregationInternal {
    /// Returns the inner [`RangeAggregation`], or `None` for other bucket types.
    pub(crate) fn as_range(&self) -> Option<&RangeAggregation> {
        if let BucketAggregationType::Range(range) = &self.bucket_agg {
            Some(range)
        } else {
            None
        }
    }
    /// Returns the inner [`HistogramAggregation`], or `None` for other bucket types.
    pub(crate) fn as_histogram(&self) -> Option<&HistogramAggregation> {
        if let BucketAggregationType::Histogram(histogram) = &self.bucket_agg {
            Some(histogram)
        } else {
            None
        }
    }
    /// Returns the inner [`TermsAggregation`], or `None` for other bucket types.
    pub(crate) fn as_term(&self) -> Option<&TermsAggregation> {
        if let BucketAggregationType::Terms(terms) = &self.bucket_agg {
            Some(terms)
        } else {
            None
        }
    }
}

/// Extract all fields, where the term dictionary is used in the tree.
pub fn get_term_dict_field_names(aggs: &Aggregations) -> HashSet<String> {
    let mut field_names = HashSet::new();
    for agg in aggs.values() {
        agg.get_term_dict_field_names(&mut field_names);
    }
    field_names
}

/// Extract all fast field names used in the tree.
pub fn get_fast_field_names(aggs: &Aggregations) -> HashSet<String> {
    let mut field_names = HashSet::new();
    for agg in aggs.values() {
        agg.get_fast_field_names(&mut field_names);
    }
    field_names
}

/// Aggregation request of [`BucketAggregation`] or [`MetricAggregation`].
///
/// An aggregation is either a bucket or a metric.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
// `untagged`: the variant is picked from the json shape itself, so the
// serialized form carries no extra tag (matching the elasticsearch format).
#[serde(untagged)]
pub enum Aggregation {
    /// Bucket aggregation, see [`BucketAggregation`] for details.
    Bucket(BucketAggregation),
    /// Metric aggregation, see [`MetricAggregation`] for details.
    Metric(MetricAggregation),
}

impl Aggregation {
    /// Collects the term-dictionary field names of this subtree. Only bucket
    /// aggregations can contribute term-dict fields; metrics contribute none.
    fn get_term_dict_field_names(&self, term_field_names: &mut HashSet<String>) {
        match self {
            Aggregation::Bucket(bucket) => bucket.get_term_dict_field_names(term_field_names),
            Aggregation::Metric(_) => {}
        }
    }

    /// Collects the fast field names of this subtree; both variants contribute.
    fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
        match self {
            Aggregation::Bucket(bucket) => bucket.get_fast_field_names(fast_field_names),
            Aggregation::Metric(metric) => metric.get_fast_field_names(fast_field_names),
        }
    }
}

/// BucketAggregations create buckets of documents. Each bucket is associated with a rule which
/// determines whether or not a document falls into it. In other words, the buckets
/// effectively define document sets. Buckets are not necessarily disjunct, therefore a document can
/// fall into multiple buckets. In addition to the buckets themselves, the bucket aggregations also
/// compute and return the number of documents for each bucket. Bucket aggregations, as opposed to
/// metric aggregations, can hold sub-aggregations. These sub-aggregations will be aggregated for
/// the buckets created by their "parent" bucket aggregation. There are different bucket
/// aggregators, each with a different "bucketing" strategy. Some define a single bucket, some
/// define fixed number of multiple buckets, and others dynamically create the buckets during the
/// aggregation process.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BucketAggregation {
    /// Bucket aggregation strategy to group documents.
    // `flatten`: the strategy serializes at this level (e.g. `"range": {...}`)
    // instead of under a `bucket_agg` key, matching the elasticsearch format.
    #[serde(flatten)]
    pub bucket_agg: BucketAggregationType,
    /// The sub_aggregations in the buckets. Each bucket will aggregate on the document set in the
    /// bucket.
    // Serialized under the elasticsearch key `aggs`; optional on input and
    // omitted on output when empty.
    #[serde(rename = "aggs")]
    #[serde(default)]
    #[serde(skip_serializing_if = "Aggregations::is_empty")]
    pub sub_aggregation: Aggregations,
}

impl BucketAggregation {
    /// Collects the fields backed by the term dictionary: the field of a terms
    /// aggregation (the only strategy using the term dict) plus any term-dict
    /// fields of the sub-aggregations.
    fn get_term_dict_field_names(&self, term_dict_field_names: &mut HashSet<String>) {
        if let BucketAggregationType::Terms(terms) = &self.bucket_agg {
            term_dict_field_names.insert(terms.field.to_string());
        }
        for sub_agg in self.sub_aggregation.values() {
            sub_agg.get_term_dict_field_names(term_dict_field_names);
        }
    }
    /// Collects the fast field of this bucket's strategy plus the fast fields
    /// of all sub-aggregations.
    fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
        self.bucket_agg.get_fast_field_names(fast_field_names);
        for sub_agg in self.sub_aggregation.values() {
            sub_agg.get_fast_field_names(fast_field_names);
        }
    }
}

/// The bucket aggregation types.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum BucketAggregationType {
    /// Put data into buckets of user-defined ranges.
    #[serde(rename = "range")]
    Range(RangeAggregation),
    /// Put data into histogram buckets (see [`HistogramAggregation`]).
    // NOTE(review): the original doc here repeated the Range description,
    // which was a copy-paste error.
    #[serde(rename = "histogram")]
    Histogram(HistogramAggregation),
    /// Put data into buckets of terms.
    #[serde(rename = "terms")]
    Terms(TermsAggregation),
}

impl BucketAggregationType {
    /// Records the fast field this bucket strategy aggregates on.
    fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
        // All three strategies read exactly one fast field.
        let field = match self {
            BucketAggregationType::Terms(terms) => &terms.field,
            BucketAggregationType::Range(range) => &range.field,
            BucketAggregationType::Histogram(histogram) => &histogram.field,
        };
        fast_field_names.insert(field.to_string());
    }
}

/// The aggregations in this family compute metrics based on values extracted
/// from the documents that are being aggregated. Values are extracted from the fast field of
/// the document.
///
/// Some aggregations output a single numeric metric (e.g. Average) and are called
/// single-value numeric metrics aggregation, others generate multiple metrics (e.g. Stats) and are
/// called multi-value numeric metrics aggregation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum MetricAggregation {
    /// Calculates the average.
    #[serde(rename = "avg")]
    Average(AverageAggregation),
    /// Calculates stats sum, average, min, max, standard_deviation on a field.
    #[serde(rename = "stats")]
    Stats(StatsAggregation),
}

impl MetricAggregation {
    /// Records the fast field this metric aggregates on.
    fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
        // Both metric kinds read exactly one fast field.
        let field = match self {
            MetricAggregation::Average(avg) => &avg.field,
            MetricAggregation::Stats(stats) => &stats.field,
        };
        fast_field_names.insert(field.to_string());
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Serializing a request must produce the exact elasticsearch-compatible
    /// json layout (flattened strategy, open-ended ranges omitting from/to).
    #[test]
    fn serialize_to_json_test() {
        let agg_req1: Aggregations = vec![(
            "range".to_string(),
            Aggregation::Bucket(BucketAggregation {
                bucket_agg: BucketAggregationType::Range(RangeAggregation {
                    field: "score".to_string(),
                    // f64::MIN / f64::MAX mark unbounded ends: they serialize
                    // as a missing "from" / "to" key.
                    ranges: vec![
                        (f64::MIN..3f64).into(),
                        (3f64..7f64).into(),
                        (7f64..20f64).into(),
                        (20f64..f64::MAX).into(),
                    ],
                    keyed: true,
                }),
                sub_aggregation: Default::default(),
            }),
        )]
        .into_iter()
        .collect();

        let elasticsearch_compatible_json_req = r#"{
  "range": {
    "range": {
      "field": "score",
      "ranges": [
        {
          "to": 3.0
        },
        {
          "from": 3.0,
          "to": 7.0
        },
        {
          "from": 7.0,
          "to": 20.0
        },
        {
          "from": 20.0
        }
      ],
      "keyed": true
    }
  }
}"#;
        let agg_req2: String = serde_json::to_string_pretty(&agg_req1).unwrap();
        assert_eq!(agg_req2, elasticsearch_compatible_json_req);
    }

    /// Fast field extraction must walk the whole tree: the outer range field,
    /// the nested range field, and the nested metric field.
    #[test]
    fn test_get_fast_field_names() {
        // Sub-aggregation level: a range on "score2" and an avg on "field123".
        let agg_req2: Aggregations = vec![
            (
                "range".to_string(),
                Aggregation::Bucket(BucketAggregation {
                    bucket_agg: BucketAggregationType::Range(RangeAggregation {
                        field: "score2".to_string(),
                        ranges: vec![
                            (f64::MIN..3f64).into(),
                            (3f64..7f64).into(),
                            (7f64..20f64).into(),
                            (20f64..f64::MAX).into(),
                        ],
                        ..Default::default()
                    }),
                    sub_aggregation: Default::default(),
                }),
            ),
            (
                "metric".to_string(),
                Aggregation::Metric(MetricAggregation::Average(
                    AverageAggregation::from_field_name("field123".to_string()),
                )),
            ),
        ]
        .into_iter()
        .collect();

        // Top level: a range on "score" holding the sub-aggregations above.
        let agg_req1: Aggregations = vec![(
            "range".to_string(),
            Aggregation::Bucket(BucketAggregation {
                bucket_agg: BucketAggregationType::Range(RangeAggregation {
                    field: "score".to_string(),
                    ranges: vec![
                        (f64::MIN..3f64).into(),
                        (3f64..7f64).into(),
                        (7f64..20f64).into(),
                        (20f64..f64::MAX).into(),
                    ],
                    ..Default::default()
                }),
                sub_aggregation: agg_req2,
            }),
        )]
        .into_iter()
        .collect();

        assert_eq!(
            get_fast_field_names(&agg_req1),
            vec![
                "score".to_string(),
                "score2".to_string(),
                "field123".to_string()
            ]
            .into_iter()
            .collect()
        )
    }
}