// opensearch_dsl/search/response/aggregation.rs

/*
 * Copyright 2023-2025 Alberto Paro
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#![allow(missing_docs)]
18
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21
22use super::Hit;
23use crate::{
24    Map, search::aggregations::Aggregation as RequestAggregation, search::params::GeoLocation,
25};
26
27/// Main aggregation trait equivalent
28pub trait AggregationTrait {
29    /// Meta information of aggregation
30    fn meta(&self) -> Option<&Value>;
31
32    /// Aggregation source
33    fn source_aggregation(&self) -> Option<&RequestAggregation>;
34
35    /// Set aggregation source
36    fn set_source_aggregation(&mut self, agg: Option<RequestAggregation>);
37
38    /// If the aggregation is empty
39    fn is_empty(&self) -> bool;
40
41    /// If the aggregation is not empty
42    fn non_empty(&self) -> bool {
43        !self.is_empty()
44    }
45
46    /// Extract label and count from an aggregation
47    fn extract_label_values(&self) -> (Vec<String>, Vec<f64>) {
48        (vec![], vec![])
49    }
50}
51
52/// Main aggregation response enum
53#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
54#[serde(untagged)]
55pub enum AggregationResponse {
56    /// Bucket aggregation response
57    Bucket(BucketAggregation),
58    /// Multi-bucket aggregation response  
59    MultiBucket(MultiBucketAggregation),
60    /// Document count aggregation response
61    DocCount(DocCountAggregation),
62    /// Geo bounds aggregation response
63    GeoBounds(GeoBoundsValue),
64    /// Extended metric statistics aggregation response
65    MetricExtendedStats(MetricExtendedStats),
66    /// Metric statistics aggregation response
67    MetricStats(MetricStats),
68    /// Metric value aggregation response
69    MetricValue(MetricValue),
70    /// Top hits aggregation response
71    TopHits(TopHitsStats),
72    /// Simple aggregation response
73    Simple(Simple),
74}
75
76impl AggregationTrait for AggregationResponse {
77    fn meta(&self) -> Option<&Value> {
78        match self {
79            AggregationResponse::Bucket(agg) => agg.meta.as_ref(),
80            AggregationResponse::MultiBucket(agg) => agg.meta.as_ref(),
81            AggregationResponse::DocCount(agg) => agg.meta.as_ref(),
82            AggregationResponse::GeoBounds(agg) => agg.meta.as_ref(),
83            AggregationResponse::MetricExtendedStats(agg) => agg.meta.as_ref(),
84            AggregationResponse::MetricStats(agg) => agg.meta.as_ref(),
85            AggregationResponse::MetricValue(agg) => agg.meta.as_ref(),
86            AggregationResponse::TopHits(agg) => agg.meta.as_ref(),
87            AggregationResponse::Simple(agg) => agg.meta.as_ref(),
88        }
89    }
90
91    fn source_aggregation(&self) -> Option<&RequestAggregation> {
92        match self {
93            AggregationResponse::Bucket(agg) => agg.source_aggregation.as_ref(),
94            AggregationResponse::MultiBucket(agg) => agg.source_aggregation.as_ref(),
95            AggregationResponse::DocCount(agg) => agg.source_aggregation.as_ref(),
96            AggregationResponse::GeoBounds(agg) => agg.source_aggregation.as_ref(),
97            AggregationResponse::MetricExtendedStats(agg) => agg.source_aggregation.as_ref(),
98            AggregationResponse::MetricStats(agg) => agg.source_aggregation.as_ref(),
99            AggregationResponse::MetricValue(agg) => agg.source_aggregation.as_ref(),
100            AggregationResponse::TopHits(agg) => agg.source_aggregation.as_ref(),
101            AggregationResponse::Simple(agg) => agg.source_aggregation.as_ref(),
102        }
103    }
104
105    fn set_source_aggregation(&mut self, agg: Option<RequestAggregation>) {
106        match self {
107            AggregationResponse::Bucket(a) => a.source_aggregation = agg,
108            AggregationResponse::MultiBucket(a) => a.source_aggregation = agg,
109            AggregationResponse::DocCount(a) => a.source_aggregation = agg,
110            AggregationResponse::GeoBounds(a) => a.source_aggregation = agg,
111            AggregationResponse::MetricExtendedStats(a) => a.source_aggregation = agg,
112            AggregationResponse::MetricStats(a) => a.source_aggregation = agg,
113            AggregationResponse::MetricValue(a) => a.source_aggregation = agg,
114            AggregationResponse::TopHits(a) => a.source_aggregation = agg,
115            AggregationResponse::Simple(a) => a.source_aggregation = agg,
116        }
117    }
118
119    fn is_empty(&self) -> bool {
120        match self {
121            AggregationResponse::Bucket(agg) => agg.buckets.is_empty(),
122            AggregationResponse::MultiBucket(_) => true,
123            AggregationResponse::DocCount(_) => false,
124            AggregationResponse::GeoBounds(_) => false,
125            AggregationResponse::MetricExtendedStats(_) => false,
126            AggregationResponse::MetricStats(_) => false,
127            AggregationResponse::MetricValue(_) => false,
128            AggregationResponse::TopHits(_) => false,
129            AggregationResponse::Simple(_) => false,
130        }
131    }
132
133    fn extract_label_values(&self) -> (Vec<String>, Vec<f64>) {
134        match self {
135            AggregationResponse::Bucket(agg) => {
136                let labels: Vec<String> = agg.buckets.iter().map(|b| b.key_to_string()).collect();
137                let values: Vec<f64> = agg.buckets.iter().map(|b| b.doc_count as f64).collect();
138                (labels, values)
139            }
140            _ => (vec![], vec![]),
141        }
142    }
143}
144
145/// TopHitsResult equivalent
146#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
147pub struct TopHitsResult {
148    pub total: i64,
149    #[serde(rename = "max_score")]
150    pub max_score: Option<f64>,
151    pub hits: Vec<Hit>,
152}
153
154impl Default for TopHitsResult {
155    fn default() -> Self {
156        Self {
157            total: 0,
158            max_score: None,
159            hits: vec![],
160        }
161    }
162}
163
164/// Simple aggregation
165#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
166pub struct Simple {
167    #[serde(rename = "_source")]
168    pub source_aggregation: Option<RequestAggregation>,
169    pub meta: Option<Value>,
170}
171
172/// Bucket structure
173#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
174pub struct Bucket {
175    pub key: Value,
176    #[serde(rename = "doc_count")]
177    pub doc_count: i64,
178    #[serde(rename = "bg_count")]
179    pub bg_count: Option<i64>,
180    pub score: Option<f64>,
181    #[serde(rename = "key_as_string")]
182    pub key_as_string: Option<String>,
183    #[serde(flatten)]
184    pub sub_aggs: Map<String, AggregationResponse>,
185}
186
187impl Bucket {
188    pub fn key_to_string(&self) -> String {
189        if let Some(ref key_as_string) = self.key_as_string {
190            key_as_string.clone()
191        } else {
192            match &self.key {
193                Value::String(s) => s.clone(),
194                _ => self.key.to_string(),
195            }
196        }
197    }
198}
199
200/// MultiBucketBucket structure
201#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
202pub struct MultiBucketBucket {
203    #[serde(rename = "doc_count")]
204    pub doc_count: i64,
205    pub buckets: Map<String, BucketAggregation>,
206}
207
208/// MultiBucketAggregation
209#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
210pub struct MultiBucketAggregation {
211    pub buckets: Map<String, MultiBucketBucket>,
212    #[serde(rename = "_source")]
213    pub source_aggregation: Option<RequestAggregation>,
214    pub meta: Option<Value>,
215}
216
217/// BucketAggregation
218#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
219pub struct BucketAggregation {
220    pub buckets: Vec<Bucket>,
221    #[serde(rename = "doc_count_error_upper_bound")]
222    pub doc_count_error_upper_bound: i64,
223    #[serde(rename = "sum_other_doc_count")]
224    pub sum_other_doc_count: i64,
225    #[serde(rename = "_source")]
226    pub source_aggregation: Option<RequestAggregation>,
227    pub meta: Option<Value>,
228}
229
230impl Default for BucketAggregation {
231    fn default() -> Self {
232        Self {
233            buckets: vec![],
234            doc_count_error_upper_bound: 0,
235            sum_other_doc_count: 0,
236            source_aggregation: None,
237            meta: None,
238        }
239    }
240}
241
242impl BucketAggregation {
243    pub fn buckets_count_as_list(&self) -> Vec<(String, i64)> {
244        self.buckets
245            .iter()
246            .map(|b| (b.key_to_string(), b.doc_count))
247            .collect()
248    }
249
250    pub fn buckets_count_as_map(&self) -> Map<String, i64> {
251        self.buckets
252            .iter()
253            .map(|b| (b.key_to_string(), b.doc_count))
254            .collect()
255    }
256}
257
258/// DocCountAggregation
259#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
260pub struct DocCountAggregation {
261    #[serde(rename = "doc_count")]
262    pub doc_count: f64,
263    #[serde(flatten)]
264    pub sub_aggs: Map<String, AggregationResponse>,
265    #[serde(rename = "_source")]
266    pub source_aggregation: Option<RequestAggregation>,
267    pub meta: Option<Value>,
268}
269
270/// GeoBoundsValue
271#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
272pub struct GeoBoundsValue {
273    #[serde(rename = "top_left")]
274    pub top_left: GeoLocation,
275    #[serde(rename = "bottom_right")]
276    pub bottom_right: GeoLocation,
277    #[serde(rename = "_source")]
278    pub source_aggregation: Option<RequestAggregation>,
279    pub meta: Option<Value>,
280}
281
282/// MetricExtendedStats
283#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
284pub struct MetricExtendedStats {
285    pub count: i64,
286    pub min: f64,
287    pub max: f64,
288    pub avg: f64,
289    pub sum: f64,
290    #[serde(rename = "sum_of_squares")]
291    pub sum_of_squares: f64,
292    pub variance: f64,
293    #[serde(rename = "std_deviation")]
294    pub std_deviation: f64,
295    #[serde(rename = "_source")]
296    pub source_aggregation: Option<RequestAggregation>,
297    pub meta: Option<Value>,
298}
299
300impl Default for MetricExtendedStats {
301    fn default() -> Self {
302        Self {
303            count: 0,
304            min: 0.0,
305            max: 0.0,
306            avg: 0.0,
307            sum: 0.0,
308            sum_of_squares: 0.0,
309            variance: 0.0,
310            std_deviation: 0.0,
311            source_aggregation: None,
312            meta: None,
313        }
314    }
315}
316
317/// TopHitsStats
318#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
319pub struct TopHitsStats {
320    pub hits: TopHitsResult,
321    #[serde(rename = "_source")]
322    pub source_aggregation: Option<RequestAggregation>,
323    pub meta: Option<Value>,
324}
325
326/// MetricStats
327#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
328pub struct MetricStats {
329    pub count: i64,
330    pub min: f64,
331    pub max: f64,
332    pub avg: f64,
333    pub sum: f64,
334    #[serde(rename = "_source")]
335    pub source_aggregation: Option<RequestAggregation>,
336    pub meta: Option<Value>,
337}
338
339impl Default for MetricStats {
340    fn default() -> Self {
341        Self {
342            count: 0,
343            min: 0.0,
344            max: 0.0,
345            avg: 0.0,
346            sum: 0.0,
347            source_aggregation: None,
348            meta: None,
349        }
350    }
351}
352
353/// MetricValue
354#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
355pub struct MetricValue {
356    pub value: f64,
357    #[serde(rename = "_source")]
358    pub source_aggregation: Option<RequestAggregation>,
359    pub meta: Option<Value>,
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a bucket with the given key, count and optional formatted key.
    fn make_bucket(key: &str, doc_count: i64, key_as_string: Option<&str>) -> Bucket {
        Bucket {
            key: json!(key),
            doc_count,
            bg_count: None,
            score: None,
            key_as_string: key_as_string.map(str::to_owned),
            sub_aggs: Map::new(),
        }
    }

    #[test]
    fn test_metric_value_serialization() {
        let serialized = serde_json::to_value(MetricValue {
            value: 42.5,
            source_aggregation: None,
            meta: None,
        })
        .unwrap();

        assert_eq!(serialized["value"], 42.5);
    }

    #[test]
    fn test_bucket_key_to_string() {
        // `key_as_string` takes precedence over the raw key.
        let formatted = make_bucket("test_key", 10, Some("test_key_string"));
        assert_eq!(formatted.key_to_string(), "test_key_string");

        // Without it, a string key is returned unquoted.
        let raw = make_bucket("another_key", 5, None);
        assert_eq!(raw.key_to_string(), "another_key");
    }

    #[test]
    fn test_aggregation_trait() {
        let agg = AggregationResponse::Bucket(BucketAggregation {
            buckets: Vec::new(),
            doc_count_error_upper_bound: 0,
            sum_other_doc_count: 0,
            source_aggregation: None,
            meta: None,
        });

        assert!(agg.is_empty());
        assert!(!agg.non_empty());
    }
}