use crate::{search::*, util::*};
/// Aggregation request that serializes as
/// `{"diversified_sampler": {...}, "aggs": {...}}`.
///
/// Presumably maps to the search server's `diversified_sampler` bucket aggregation
/// (TODO confirm against the targeted server version). Create one with
/// [`Aggregation::diversified_sampler`] and refine it with the builder methods on this type.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct DiversifiedSamplerAggregation {
/// Sampler settings, serialized under the `diversified_sampler` key.
diversified_sampler: DiversifiedSamplerAggregationInner,
/// Named sub-aggregations, serialized under the `aggs` key.
// Left out of the output whenever `ShouldSkip::should_skip` returns true
// (presumably when no sub-aggregation was added — verify in `util`).
#[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
aggs: Aggregations,
}
/// Value of the sampler's optional `execution_hint` setting.
///
/// Each variant is (de)serialized as its snake_case name: `"map"`,
/// `"bytes_hash"` or `"global_ordinals"`.
// `Hash` is derived next to `Eq`/`Copy` so the hint can be used as a map or set
// key. Adding it is purely additive and has no effect on serialization.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Copy, Hash)]
#[serde(rename_all = "snake_case")]
pub enum ExecutionHint {
    /// Serialized as `"map"`.
    Map,
    /// Serialized as `"bytes_hash"`.
    BytesHash,
    /// Serialized as `"global_ordinals"`.
    GlobalOrdinals,
}
/// Settings object serialized under the `diversified_sampler` key.
///
/// Only `field` is always present. Each optional setting is written only when it
/// has been set; unset (`None`) options are omitted from the output.
// `Eq` is derived because every field type (`String`, `Option<u64>`,
// `Option<ExecutionHint>`) is `Eq` (see `clippy::derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
struct DiversifiedSamplerAggregationInner {
    /// Target field name, serialized as `field`.
    field: String,
    /// Optional `shard_size` setting.
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    shard_size: Option<u64>,
    /// Optional `max_docs_per_value` setting.
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    max_docs_per_value: Option<u64>,
    /// Optional `execution_hint` setting.
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    execution_hint: Option<ExecutionHint>,
}
impl Aggregation {
    /// Starts a `diversified_sampler` aggregation targeting `field`.
    ///
    /// `shard_size`, `max_docs_per_value` and `execution_hint` all start unset.
    /// The sub-aggregation collection starts as `Aggregations::new()`. Use the
    /// builder methods on [`DiversifiedSamplerAggregation`] to fill them in.
    pub fn diversified_sampler<T>(field: T) -> DiversifiedSamplerAggregation
    where
        T: ToString,
    {
        // Stringify the field first, matching the original evaluation order.
        let settings = DiversifiedSamplerAggregationInner {
            field: field.to_string(),
            shard_size: None,
            max_docs_per_value: None,
            execution_hint: None,
        };
        let sub_aggregations = Aggregations::new();

        DiversifiedSamplerAggregation {
            diversified_sampler: settings,
            aggs: sub_aggregations,
        }
    }
}
// Every setter consumes `self` and returns the updated value. `#[must_use]`
// turns an accidentally discarded result (which silently drops the setting)
// into a compiler warning.
impl DiversifiedSamplerAggregation {
    /// Sets the `shard_size` option of the sampler settings.
    #[must_use]
    pub fn shard_size(mut self, shard_size: u64) -> Self {
        self.diversified_sampler.shard_size = Some(shard_size);
        self
    }

    /// Sets the `max_docs_per_value` option of the sampler settings.
    #[must_use]
    pub fn max_docs_per_value(mut self, max_docs_per_value: u64) -> Self {
        self.diversified_sampler.max_docs_per_value = Some(max_docs_per_value);
        self
    }

    /// Sets the `execution_hint` option of the sampler settings.
    #[must_use]
    pub fn execution_hint(mut self, execution_hint: ExecutionHint) -> Self {
        self.diversified_sampler.execution_hint = Some(execution_hint);
        self
    }

    // Macro defined elsewhere in the crate; presumably generates the
    // `aggregate(name, aggregation)` method that fills `aggs` — verify in `util`.
    add_aggregate!();
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serialization() {
        assert_serialize_aggregation(
            Aggregation::diversified_sampler("catalog_id").shard_size(50),
            json!({
                "diversified_sampler": {
                    "field": "catalog_id",
                    "shard_size": 50
                }
            }),
        );

        assert_serialize_aggregation(
            Aggregation::diversified_sampler("catalog_id")
                .shard_size(50)
                .max_docs_per_value(2)
                .execution_hint(ExecutionHint::GlobalOrdinals)
                .aggregate("catalog", Aggregation::terms("catalog_id"))
                .aggregate("brand", Aggregation::terms("brand_id")),
            json!({
                "diversified_sampler": {
                    "field": "catalog_id",
                    "shard_size": 50,
                    "max_docs_per_value": 2,
                    "execution_hint": "global_ordinals"
                },
                "aggs": {
                    "catalog": {
                        "terms": {
                            "field": "catalog_id"
                        }
                    },
                    "brand": {
                        "terms": {
                            "field": "brand_id"
                        }
                    }
                }
            }),
        );
    }

    #[test]
    fn serialization_of_each_execution_hint() {
        // Every variant must serialize to its snake_case wire name, not only
        // `GlobalOrdinals`. Iterating by reference works on every edition.
        for &(hint, expected) in &[
            (ExecutionHint::Map, "map"),
            (ExecutionHint::BytesHash, "bytes_hash"),
            (ExecutionHint::GlobalOrdinals, "global_ordinals"),
        ] {
            assert_serialize_aggregation(
                Aggregation::diversified_sampler("catalog_id").execution_hint(hint),
                json!({
                    "diversified_sampler": {
                        "field": "catalog_id",
                        "execution_hint": expected
                    }
                }),
            );
        }
    }

    #[test]
    fn serialization_omits_unset_options() {
        // Only `max_docs_per_value` is set: `shard_size`, `execution_hint`
        // and the empty `aggs` must all be left out.
        assert_serialize_aggregation(
            Aggregation::diversified_sampler("brand_id").max_docs_per_value(3),
            json!({
                "diversified_sampler": {
                    "field": "brand_id",
                    "max_docs_per_value": 3
                }
            }),
        );
    }
}