1use crate::error::{Result, StacError};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct AggregationRequest {
13 pub aggregations: Vec<Aggregation>,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19#[serde(tag = "type")]
20pub enum Aggregation {
21 #[serde(rename = "count")]
23 Count {
24 name: String,
26 #[serde(skip_serializing_if = "Option::is_none")]
28 field: Option<String>,
29 },
30
31 #[serde(rename = "sum")]
33 Sum {
34 name: String,
36 field: String,
38 },
39
40 #[serde(rename = "avg")]
42 Avg {
43 name: String,
45 field: String,
47 },
48
49 #[serde(rename = "min")]
51 Min {
52 name: String,
54 field: String,
56 },
57
58 #[serde(rename = "max")]
60 Max {
61 name: String,
63 field: String,
65 },
66
67 #[serde(rename = "stats")]
69 Stats {
70 name: String,
72 field: String,
74 },
75
76 #[serde(rename = "terms")]
78 Terms {
79 name: String,
81 field: String,
83 #[serde(skip_serializing_if = "Option::is_none")]
85 size: Option<u32>,
86 },
87
88 #[serde(rename = "histogram")]
90 Histogram {
91 name: String,
93 field: String,
95 interval: f64,
97 #[serde(skip_serializing_if = "Option::is_none")]
99 min: Option<f64>,
100 #[serde(skip_serializing_if = "Option::is_none")]
102 max: Option<f64>,
103 },
104
105 #[serde(rename = "date_histogram")]
107 DateHistogram {
108 name: String,
110 field: String,
112 interval: String,
114 #[serde(skip_serializing_if = "Option::is_none")]
116 time_zone: Option<String>,
117 },
118
119 #[serde(rename = "geohash_grid")]
121 GeohashGrid {
122 name: String,
124 field: String,
126 precision: u8,
128 },
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct AggregationResponse {
134 pub aggregations: HashMap<String, AggregationResult>,
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
140#[serde(untagged)]
141pub enum AggregationResult {
142 Value(f64),
144
145 Stats {
147 count: u64,
149 sum: f64,
151 avg: f64,
153 min: f64,
155 max: f64,
157 },
158
159 Buckets {
161 buckets: Vec<Bucket>,
163 },
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct Bucket {
169 pub key: Value,
171 pub doc_count: u64,
173 #[serde(skip_serializing_if = "Option::is_none")]
175 pub aggregations: Option<HashMap<String, AggregationResult>>,
176}
177
178impl AggregationRequest {
179 pub fn new() -> Self {
181 Self {
182 aggregations: Vec::new(),
183 }
184 }
185
186 pub fn add(mut self, aggregation: Aggregation) -> Self {
188 self.aggregations.push(aggregation);
189 self
190 }
191
192 pub fn count(name: impl Into<String>) -> Aggregation {
194 Aggregation::Count {
195 name: name.into(),
196 field: None,
197 }
198 }
199
200 pub fn sum(name: impl Into<String>, field: impl Into<String>) -> Aggregation {
202 Aggregation::Sum {
203 name: name.into(),
204 field: field.into(),
205 }
206 }
207
208 pub fn avg(name: impl Into<String>, field: impl Into<String>) -> Aggregation {
210 Aggregation::Avg {
211 name: name.into(),
212 field: field.into(),
213 }
214 }
215
216 pub fn min(name: impl Into<String>, field: impl Into<String>) -> Aggregation {
218 Aggregation::Min {
219 name: name.into(),
220 field: field.into(),
221 }
222 }
223
224 pub fn max(name: impl Into<String>, field: impl Into<String>) -> Aggregation {
226 Aggregation::Max {
227 name: name.into(),
228 field: field.into(),
229 }
230 }
231
232 pub fn stats(name: impl Into<String>, field: impl Into<String>) -> Aggregation {
234 Aggregation::Stats {
235 name: name.into(),
236 field: field.into(),
237 }
238 }
239
240 pub fn terms(name: impl Into<String>, field: impl Into<String>) -> Aggregation {
242 Aggregation::Terms {
243 name: name.into(),
244 field: field.into(),
245 size: None,
246 }
247 }
248
249 pub fn histogram(
251 name: impl Into<String>,
252 field: impl Into<String>,
253 interval: f64,
254 ) -> Aggregation {
255 Aggregation::Histogram {
256 name: name.into(),
257 field: field.into(),
258 interval,
259 min: None,
260 max: None,
261 }
262 }
263
264 pub fn date_histogram(
266 name: impl Into<String>,
267 field: impl Into<String>,
268 interval: impl Into<String>,
269 ) -> Aggregation {
270 Aggregation::DateHistogram {
271 name: name.into(),
272 field: field.into(),
273 interval: interval.into(),
274 time_zone: None,
275 }
276 }
277
278 pub fn geohash_grid(
280 name: impl Into<String>,
281 field: impl Into<String>,
282 precision: u8,
283 ) -> Result<Aggregation> {
284 if !(1..=12).contains(&precision) {
285 return Err(StacError::InvalidFieldValue {
286 field: "precision".to_string(),
287 reason: "must be between 1 and 12".to_string(),
288 });
289 }
290
291 Ok(Aggregation::GeohashGrid {
292 name: name.into(),
293 field: field.into(),
294 precision,
295 })
296 }
297}
298
299impl Default for AggregationRequest {
300 fn default() -> Self {
301 Self::new()
302 }
303}
304
305impl AggregationResponse {
306 pub fn get_value(&self, name: &str) -> Option<f64> {
308 self.aggregations.get(name).and_then(|result| {
309 if let AggregationResult::Value(v) = result {
310 Some(*v)
311 } else {
312 None
313 }
314 })
315 }
316
317 pub fn get_stats(&self, name: &str) -> Option<(u64, f64, f64, f64, f64)> {
319 self.aggregations.get(name).and_then(|result| {
320 if let AggregationResult::Stats {
321 count,
322 sum,
323 avg,
324 min,
325 max,
326 } = result
327 {
328 Some((*count, *sum, *avg, *min, *max))
329 } else {
330 None
331 }
332 })
333 }
334
335 pub fn get_buckets(&self, name: &str) -> Option<&Vec<Bucket>> {
337 self.aggregations.get(name).and_then(|result| {
338 if let AggregationResult::Buckets { buckets } = result {
339 Some(buckets)
340 } else {
341 None
342 }
343 })
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_aggregation_request_builder() {
        let request = AggregationRequest::new()
            .add(AggregationRequest::count("total"))
            .add(AggregationRequest::avg("avg_cloud_cover", "eo:cloud_cover"))
            .add(AggregationRequest::terms("platforms", "platform"));

        assert_eq!(request.aggregations.len(), 3);
    }

    #[test]
    fn test_aggregation_serialization() {
        let request = AggregationRequest::new()
            .add(AggregationRequest::count("total"))
            .add(AggregationRequest::stats(
                "cloud_cover_stats",
                "eo:cloud_cover",
            ));

        // Compare the exact JSON: a substring check for "stats" would also match the
        // aggregation *name* "cloud_cover_stats" and never verify the `type` tag.
        // This also pins that a `None` optional field is omitted entirely.
        let json = serde_json::to_value(&request).expect("Failed to serialize");
        assert_eq!(
            json,
            json!({
                "aggregations": [
                    { "type": "count", "name": "total" },
                    {
                        "type": "stats",
                        "name": "cloud_cover_stats",
                        "field": "eo:cloud_cover"
                    }
                ]
            })
        );
    }

    #[test]
    fn test_multi_word_variant_tags() {
        let date_histogram = serde_json::to_value(AggregationRequest::date_histogram(
            "by_month", "datetime", "month",
        ))
        .expect("Failed to serialize date histogram");
        assert_eq!(date_histogram["type"], "date_histogram");

        let grid = AggregationRequest::geohash_grid("geo_grid", "geometry", 5)
            .expect("precision 5 is valid");
        let grid = serde_json::to_value(grid).expect("Failed to serialize geohash grid");
        assert_eq!(grid["type"], "geohash_grid");
    }

    #[test]
    fn test_aggregation_request_round_trip() {
        let original = json!({
            "aggregations": [
                { "type": "terms", "name": "platforms", "field": "platform", "size": 10 }
            ]
        });

        let request: AggregationRequest =
            serde_json::from_value(original.clone()).expect("Failed to deserialize");
        match &request.aggregations[..] {
            [Aggregation::Terms { name, field, size }] => {
                assert_eq!(name, "platforms");
                assert_eq!(field, "platform");
                assert_eq!(*size, Some(10));
            }
            other => panic!("unexpected aggregations: {:?}", other),
        }

        let reserialized = serde_json::to_value(&request).expect("Failed to serialize");
        assert_eq!(reserialized, original);
    }

    #[test]
    fn test_geohash_grid_validation() {
        // Both ends of the accepted range, plus an interior value.
        for precision in [1u8, 5, 12] {
            assert!(
                AggregationRequest::geohash_grid("geo_grid", "geometry", precision).is_ok(),
                "precision {} should be accepted",
                precision
            );
        }
        // Just outside each end, plus a value well past the upper bound.
        for precision in [0u8, 13, 15] {
            assert!(
                AggregationRequest::geohash_grid("geo_grid", "geometry", precision).is_err(),
                "precision {} should be rejected",
                precision
            );
        }
    }

    #[test]
    fn test_aggregation_response() {
        let mut aggregations = HashMap::new();
        aggregations.insert("total".to_string(), AggregationResult::Value(1000.0));
        aggregations.insert(
            "stats".to_string(),
            AggregationResult::Stats {
                count: 100,
                sum: 1500.0,
                avg: 15.0,
                min: 0.0,
                max: 100.0,
            },
        );

        let response = AggregationResponse { aggregations };

        assert_eq!(response.get_value("total"), Some(1000.0));
        assert_eq!(
            response.get_stats("stats"),
            Some((100, 1500.0, 15.0, 0.0, 100.0))
        );
    }

    #[test]
    fn test_aggregation_response_deserialization() {
        // Exercises the untagged variant selection for each result shape,
        // including an integer JSON number landing in `Value(f64)`.
        let response: AggregationResponse = serde_json::from_value(json!({
            "aggregations": {
                "total": 42,
                "cloud": { "count": 2, "sum": 30.0, "avg": 15.0, "min": 10.0, "max": 20.0 },
                "platforms": { "buckets": [ { "key": "sentinel-2a", "doc_count": 7 } ] }
            }
        }))
        .expect("Failed to deserialize");

        assert_eq!(response.get_value("total"), Some(42.0));
        assert_eq!(
            response.get_stats("cloud"),
            Some((2, 30.0, 15.0, 10.0, 20.0))
        );

        let buckets = response
            .get_buckets("platforms")
            .expect("platforms should hold buckets");
        assert_eq!(buckets.len(), 1);
        assert_eq!(buckets[0].key, json!("sentinel-2a"));
        assert_eq!(buckets[0].doc_count, 7);
        assert!(buckets[0].aggregations.is_none());

        // Accessors return None on a shape mismatch or an unknown name.
        assert_eq!(response.get_value("cloud"), None);
        assert!(response.get_buckets("total").is_none());
        assert_eq!(response.get_stats("missing"), None);
    }
}