Skip to main content

oxigdal_stac/
aggregation.rs

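//! Aggregation support for the STAC API.
//!
//! This module defines the request and response types for aggregations over
//! STAC search results, together with helpers for building requests and
//! reading individual results by name.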
4
5use crate::error::{Result, StacError};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9
10/// Aggregation request for STAC API.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct AggregationRequest {
13    /// Aggregations to compute.
14    pub aggregations: Vec<Aggregation>,
15}
16
17/// Single aggregation specification.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19#[serde(tag = "type")]
20pub enum Aggregation {
21    /// Count aggregation.
22    #[serde(rename = "count")]
23    Count {
24        /// Name of the aggregation.
25        name: String,
26        /// Field to count (optional, counts all if not specified).
27        #[serde(skip_serializing_if = "Option::is_none")]
28        field: Option<String>,
29    },
30
31    /// Sum aggregation.
32    #[serde(rename = "sum")]
33    Sum {
34        /// Name of the aggregation.
35        name: String,
36        /// Field to sum.
37        field: String,
38    },
39
40    /// Average aggregation.
41    #[serde(rename = "avg")]
42    Avg {
43        /// Name of the aggregation.
44        name: String,
45        /// Field to average.
46        field: String,
47    },
48
49    /// Min aggregation.
50    #[serde(rename = "min")]
51    Min {
52        /// Name of the aggregation.
53        name: String,
54        /// Field to find minimum.
55        field: String,
56    },
57
58    /// Max aggregation.
59    #[serde(rename = "max")]
60    Max {
61        /// Name of the aggregation.
62        name: String,
63        /// Field to find maximum.
64        field: String,
65    },
66
67    /// Stats aggregation (count, sum, avg, min, max).
68    #[serde(rename = "stats")]
69    Stats {
70        /// Name of the aggregation.
71        name: String,
72        /// Field to compute statistics.
73        field: String,
74    },
75
76    /// Terms aggregation (frequency count by value).
77    #[serde(rename = "terms")]
78    Terms {
79        /// Name of the aggregation.
80        name: String,
81        /// Field to group by.
82        field: String,
83        /// Maximum number of buckets to return.
84        #[serde(skip_serializing_if = "Option::is_none")]
85        size: Option<u32>,
86    },
87
88    /// Histogram aggregation (bucketed numeric values).
89    #[serde(rename = "histogram")]
90    Histogram {
91        /// Name of the aggregation.
92        name: String,
93        /// Field to histogram.
94        field: String,
95        /// Interval for buckets.
96        interval: f64,
97        /// Minimum value.
98        #[serde(skip_serializing_if = "Option::is_none")]
99        min: Option<f64>,
100        /// Maximum value.
101        #[serde(skip_serializing_if = "Option::is_none")]
102        max: Option<f64>,
103    },
104
105    /// Date histogram aggregation (bucketed temporal values).
106    #[serde(rename = "date_histogram")]
107    DateHistogram {
108        /// Name of the aggregation.
109        name: String,
110        /// Field to histogram.
111        field: String,
112        /// Calendar interval (e.g., "1d", "1M", "1y").
113        interval: String,
114        /// Time zone (e.g., "UTC", "America/New_York").
115        #[serde(skip_serializing_if = "Option::is_none")]
116        time_zone: Option<String>,
117    },
118
119    /// Geohash grid aggregation (spatial bucketing).
120    #[serde(rename = "geohash_grid")]
121    GeohashGrid {
122        /// Name of the aggregation.
123        name: String,
124        /// Field containing geometries.
125        field: String,
126        /// Geohash precision (1-12).
127        precision: u8,
128    },
129}
130
131/// Aggregation response from STAC API.
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct AggregationResponse {
134    /// Aggregation results keyed by aggregation name.
135    pub aggregations: HashMap<String, AggregationResult>,
136}
137
138/// Result of a single aggregation.
139#[derive(Debug, Clone, Serialize, Deserialize)]
140#[serde(untagged)]
141pub enum AggregationResult {
142    /// Simple numeric value.
143    Value(f64),
144
145    /// Statistics result.
146    Stats {
147        /// Count of values.
148        count: u64,
149        /// Sum of values.
150        sum: f64,
151        /// Average of values.
152        avg: f64,
153        /// Minimum value.
154        min: f64,
155        /// Maximum value.
156        max: f64,
157    },
158
159    /// Bucketed results.
160    Buckets {
161        /// List of buckets.
162        buckets: Vec<Bucket>,
163    },
164}
165
166/// Bucket in a terms or histogram aggregation.
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct Bucket {
169    /// Bucket key (value or range).
170    pub key: Value,
171    /// Count of items in this bucket.
172    pub doc_count: u64,
173    /// Sub-aggregations (if any).
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub aggregations: Option<HashMap<String, AggregationResult>>,
176}
177
178impl AggregationRequest {
179    /// Creates a new aggregation request.
180    pub fn new() -> Self {
181        Self {
182            aggregations: Vec::new(),
183        }
184    }
185
186    /// Adds an aggregation to the request.
187    pub fn add(mut self, aggregation: Aggregation) -> Self {
188        self.aggregations.push(aggregation);
189        self
190    }
191
192    /// Creates a count aggregation.
193    pub fn count(name: impl Into<String>) -> Aggregation {
194        Aggregation::Count {
195            name: name.into(),
196            field: None,
197        }
198    }
199
200    /// Creates a sum aggregation.
201    pub fn sum(name: impl Into<String>, field: impl Into<String>) -> Aggregation {
202        Aggregation::Sum {
203            name: name.into(),
204            field: field.into(),
205        }
206    }
207
208    /// Creates an average aggregation.
209    pub fn avg(name: impl Into<String>, field: impl Into<String>) -> Aggregation {
210        Aggregation::Avg {
211            name: name.into(),
212            field: field.into(),
213        }
214    }
215
216    /// Creates a min aggregation.
217    pub fn min(name: impl Into<String>, field: impl Into<String>) -> Aggregation {
218        Aggregation::Min {
219            name: name.into(),
220            field: field.into(),
221        }
222    }
223
224    /// Creates a max aggregation.
225    pub fn max(name: impl Into<String>, field: impl Into<String>) -> Aggregation {
226        Aggregation::Max {
227            name: name.into(),
228            field: field.into(),
229        }
230    }
231
232    /// Creates a stats aggregation.
233    pub fn stats(name: impl Into<String>, field: impl Into<String>) -> Aggregation {
234        Aggregation::Stats {
235            name: name.into(),
236            field: field.into(),
237        }
238    }
239
240    /// Creates a terms aggregation.
241    pub fn terms(name: impl Into<String>, field: impl Into<String>) -> Aggregation {
242        Aggregation::Terms {
243            name: name.into(),
244            field: field.into(),
245            size: None,
246        }
247    }
248
249    /// Creates a histogram aggregation.
250    pub fn histogram(
251        name: impl Into<String>,
252        field: impl Into<String>,
253        interval: f64,
254    ) -> Aggregation {
255        Aggregation::Histogram {
256            name: name.into(),
257            field: field.into(),
258            interval,
259            min: None,
260            max: None,
261        }
262    }
263
264    /// Creates a date histogram aggregation.
265    pub fn date_histogram(
266        name: impl Into<String>,
267        field: impl Into<String>,
268        interval: impl Into<String>,
269    ) -> Aggregation {
270        Aggregation::DateHistogram {
271            name: name.into(),
272            field: field.into(),
273            interval: interval.into(),
274            time_zone: None,
275        }
276    }
277
278    /// Creates a geohash grid aggregation.
279    pub fn geohash_grid(
280        name: impl Into<String>,
281        field: impl Into<String>,
282        precision: u8,
283    ) -> Result<Aggregation> {
284        if !(1..=12).contains(&precision) {
285            return Err(StacError::InvalidFieldValue {
286                field: "precision".to_string(),
287                reason: "must be between 1 and 12".to_string(),
288            });
289        }
290
291        Ok(Aggregation::GeohashGrid {
292            name: name.into(),
293            field: field.into(),
294            precision,
295        })
296    }
297}
298
299impl Default for AggregationRequest {
300    fn default() -> Self {
301        Self::new()
302    }
303}
304
305impl AggregationResponse {
306    /// Gets a single value result by name.
307    pub fn get_value(&self, name: &str) -> Option<f64> {
308        self.aggregations.get(name).and_then(|result| {
309            if let AggregationResult::Value(v) = result {
310                Some(*v)
311            } else {
312                None
313            }
314        })
315    }
316
317    /// Gets stats result by name.
318    pub fn get_stats(&self, name: &str) -> Option<(u64, f64, f64, f64, f64)> {
319        self.aggregations.get(name).and_then(|result| {
320            if let AggregationResult::Stats {
321                count,
322                sum,
323                avg,
324                min,
325                max,
326            } = result
327            {
328                Some((*count, *sum, *avg, *min, *max))
329            } else {
330                None
331            }
332        })
333    }
334
335    /// Gets buckets result by name.
336    pub fn get_buckets(&self, name: &str) -> Option<&Vec<Bucket>> {
337        self.aggregations.get(name).and_then(|result| {
338            if let AggregationResult::Buckets { buckets } = result {
339                Some(buckets)
340            } else {
341                None
342            }
343        })
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder keeps every aggregation, in insertion order.
    #[test]
    fn test_aggregation_request_builder() {
        let request = AggregationRequest::new()
            .add(AggregationRequest::count("total"))
            .add(AggregationRequest::avg("avg_cloud_cover", "eo:cloud_cover"))
            .add(AggregationRequest::terms("platforms", "platform"));

        assert_eq!(request.aggregations.len(), 3);
        assert!(matches!(
            &request.aggregations[0],
            Aggregation::Count { name, field: None } if name == "total"
        ));
        assert!(matches!(
            &request.aggregations[1],
            Aggregation::Avg { name, field }
                if name == "avg_cloud_cover" && field == "eo:cloud_cover"
        ));
        assert!(matches!(
            &request.aggregations[2],
            Aggregation::Terms { name, field, size: None }
                if name == "platforms" && field == "platform"
        ));
    }

    /// The serialized form carries the `type` tag and omits unset options.
    #[test]
    fn test_aggregation_serialization() {
        let request = AggregationRequest::new()
            .add(AggregationRequest::count("total"))
            .add(AggregationRequest::stats(
                "cloud_cover_stats",
                "eo:cloud_cover",
            ));

        // Inspect the JSON tree rather than substring-matching the text: the
        // name "cloud_cover_stats" already contains "stats", so a substring
        // check would pass even if the type tag were wrong.
        let json = serde_json::to_value(&request).expect("Failed to serialize");
        let aggs = &json["aggregations"];

        assert_eq!(aggs[0]["type"], "count");
        assert_eq!(aggs[0]["name"], "total");
        assert!(aggs[0].get("field").is_none());

        assert_eq!(aggs[1]["type"], "stats");
        assert_eq!(aggs[1]["name"], "cloud_cover_stats");
        assert_eq!(aggs[1]["field"], "eo:cloud_cover");
    }

    /// Precision is accepted on 1..=12 and rejected just outside it.
    #[test]
    fn test_geohash_grid_validation() {
        let valid = AggregationRequest::geohash_grid("geo_grid", "geometry", 5);
        assert!(valid.is_ok());

        let invalid = AggregationRequest::geohash_grid("geo_grid", "geometry", 15);
        assert!(invalid.is_err());

        // Boundaries of the accepted range.
        assert!(AggregationRequest::geohash_grid("geo_grid", "geometry", 0).is_err());
        assert!(AggregationRequest::geohash_grid("geo_grid", "geometry", 1).is_ok());
        assert!(AggregationRequest::geohash_grid("geo_grid", "geometry", 12).is_ok());
        assert!(AggregationRequest::geohash_grid("geo_grid", "geometry", 13).is_err());
    }

    /// Each accessor returns its own result kind and `None` for anything else.
    #[test]
    fn test_aggregation_response() {
        let mut aggregations = HashMap::new();
        aggregations.insert("total".to_string(), AggregationResult::Value(1000.0));
        aggregations.insert(
            "stats".to_string(),
            AggregationResult::Stats {
                count: 100,
                sum: 1500.0,
                avg: 15.0,
                min: 0.0,
                max: 100.0,
            },
        );
        aggregations.insert(
            "platforms".to_string(),
            AggregationResult::Buckets {
                buckets: vec![Bucket {
                    key: serde_json::Value::from("landsat-8"),
                    doc_count: 7,
                    aggregations: None,
                }],
            },
        );

        let response = AggregationResponse { aggregations };

        assert_eq!(response.get_value("total"), Some(1000.0));
        assert_eq!(
            response.get_stats("stats"),
            Some((100, 1500.0, 15.0, 0.0, 100.0))
        );

        let buckets = response
            .get_buckets("platforms")
            .expect("buckets result should be present");
        assert_eq!(buckets.len(), 1);
        assert_eq!(buckets[0].key, "landsat-8");
        assert_eq!(buckets[0].doc_count, 7);

        // A result of a different kind is not returned.
        assert_eq!(response.get_value("stats"), None);
        assert_eq!(response.get_stats("total"), None);
        assert!(response.get_buckets("total").is_none());

        // An unknown name yields nothing from any accessor.
        assert_eq!(response.get_value("missing"), None);
        assert_eq!(response.get_stats("missing"), None);
        assert!(response.get_buckets("missing").is_none());
    }

    /// The untagged result enum picks the variant from the JSON shape.
    #[test]
    fn test_aggregation_response_deserialization() {
        let text = r#"{
            "aggregations": {
                "total": 42.0,
                "cloud": {"count": 2, "sum": 3.0, "avg": 1.5, "min": 1.0, "max": 2.0},
                "platforms": {"buckets": [{"key": "landsat-8", "doc_count": 7}]}
            }
        }"#;

        let response: AggregationResponse =
            serde_json::from_str(text).expect("Failed to deserialize");

        assert_eq!(response.get_value("total"), Some(42.0));
        assert_eq!(response.get_stats("cloud"), Some((2, 3.0, 1.5, 1.0, 2.0)));

        let buckets = response
            .get_buckets("platforms")
            .expect("buckets result should be present");
        assert_eq!(buckets.len(), 1);
        assert_eq!(buckets[0].key, "landsat-8");
        assert_eq!(buckets[0].doc_count, 7);
        assert!(buckets[0].aggregations.is_none());
    }
}