1use crate::error::{Result, TemporalError};
7use crate::timeseries::{TemporalMetadata, TimeSeriesRaster};
8use chrono::{DateTime, Datelike, Duration, Utc};
9use scirs2_core::ndarray::Array3;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use tracing::{debug, info};
13
/// Observations grouped by window key; each value holds `(timestamp, raster)`
/// pairs borrowed from the source time series.
type TemporalGroups<'a> = HashMap<String, Vec<(DateTime<Utc>, &'a Array3<f64>)>>;
16
/// Temporal grouping window used to bin observations before aggregation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TemporalWindow {
    /// One calendar day per group (keyed by acquisition date).
    Daily,
    /// One ISO week per group (keyed as `YYYY-Www`).
    Weekly,
    /// One calendar month per group (keyed as `YYYY-MM`).
    Monthly,
    /// One calendar year per group.
    Yearly,
    /// Consecutive fixed-length windows of the given number of days
    /// (must be positive).
    CustomDays(i64),
    /// Rolling window over the given number of consecutive observations
    /// (must be > 0 and no larger than the series length).
    Rolling(usize),
}
33
/// Statistic computed per aggregation window, cell by cell.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregationStatistic {
    /// Arithmetic mean of valid observations.
    Mean,
    /// Median of valid observations (average of the middle two for even counts).
    Median,
    /// Minimum valid observation.
    Min,
    /// Maximum valid observation.
    Max,
    /// Sum of valid observations (cells with none stay 0).
    Sum,
    /// Population standard deviation (nodata unless at least 2 valid values).
    StdDev,
    /// Number of valid observations per cell.
    Count,
    /// First raster in the window, cloned as-is.
    First,
    /// Last raster in the window, cloned as-is.
    Last,
}
56
/// Configuration for temporal aggregation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationConfig {
    /// Grouping window.
    pub window: TemporalWindow,
    /// Statistics to compute for each window.
    pub statistics: Vec<AggregationStatistic>,
    /// Value treated as "no data"; a NaN marker excludes NaN cells.
    pub nodata: Option<f64>,
    /// Minimum number of observations a window needs to be aggregated;
    /// windows below this threshold are skipped.
    pub min_observations: usize,
}
69
impl Default for AggregationConfig {
    /// Monthly mean with NaN-as-nodata, requiring a single observation.
    fn default() -> Self {
        Self {
            window: TemporalWindow::Monthly,
            statistics: vec![AggregationStatistic::Mean],
            // NaN nodata means NaN cells are excluded from aggregation.
            nodata: Some(f64::NAN),
            min_observations: 1,
        }
    }
}
80
/// Output of a temporal aggregation: one [`TimeSeriesRaster`] per requested
/// statistic, keyed by the statistic's `Debug` name (e.g. "Mean").
#[derive(Debug, Clone)]
pub struct AggregationResult {
    /// Aggregated series keyed by statistic name.
    pub time_series: HashMap<String, TimeSeriesRaster>,
    /// Start of each aggregation window.
    // NOTE(review): not populated by `TemporalAggregator` in this file — confirm intended use.
    pub window_starts: Vec<DateTime<Utc>>,
    /// End of each aggregation window.
    // NOTE(review): not populated by `TemporalAggregator` in this file — confirm intended use.
    pub window_ends: Vec<DateTime<Utc>>,
}
91
impl AggregationResult {
    /// Create an empty result with no statistics or window bounds.
    #[must_use]
    pub fn new() -> Self {
        Self {
            time_series: HashMap::new(),
            window_starts: Vec::new(),
            window_ends: Vec::new(),
        }
    }

    /// Look up the aggregated series for a statistic by its `Debug` name
    /// (e.g. "Mean", "StdDev").
    #[must_use]
    pub fn get(&self, stat: &str) -> Option<&TimeSeriesRaster> {
        self.time_series.get(stat)
    }

    /// Insert (or replace) the aggregated series for a statistic.
    pub fn add(&mut self, stat: String, ts: TimeSeriesRaster) {
        self.time_series.insert(stat, ts);
    }
}
114
impl Default for AggregationResult {
    /// Equivalent to [`AggregationResult::new`].
    fn default() -> Self {
        Self::new()
    }
}
120
/// Stateless engine that computes per-window statistics over a [`TimeSeriesRaster`].
pub struct TemporalAggregator;
123
124impl TemporalAggregator {
    /// Aggregate a time series according to the configured temporal window.
    ///
    /// Dispatches to the window-specific implementation and returns one
    /// aggregated series per statistic in `config.statistics`.
    ///
    /// # Errors
    /// Propagates errors from the window-specific aggregators: missing shape
    /// information, empty series, or invalid window parameters.
    pub fn aggregate(
        ts: &TimeSeriesRaster,
        config: &AggregationConfig,
    ) -> Result<AggregationResult> {
        match config.window {
            TemporalWindow::Daily => Self::aggregate_daily(ts, config),
            TemporalWindow::Weekly => Self::aggregate_weekly(ts, config),
            TemporalWindow::Monthly => Self::aggregate_monthly(ts, config),
            TemporalWindow::Yearly => Self::aggregate_yearly(ts, config),
            TemporalWindow::CustomDays(days) => Self::aggregate_custom_days(ts, config, days),
            TemporalWindow::Rolling(size) => Self::aggregate_rolling(ts, config, size),
        }
    }
142
143 fn aggregate_daily(
145 ts: &TimeSeriesRaster,
146 config: &AggregationConfig,
147 ) -> Result<AggregationResult> {
148 let (height, width, n_bands) = ts
149 .expected_shape()
150 .ok_or_else(|| TemporalError::insufficient_data("No shape information"))?;
151
152 let mut daily_groups: TemporalGroups = HashMap::new();
154
155 for (_, entry) in ts.iter() {
156 let date_key = entry.metadata.acquisition_date.to_string();
157 if let Some(data) = entry.data.as_ref() {
158 daily_groups
159 .entry(date_key)
160 .or_default()
161 .push((entry.metadata.timestamp, data));
162 }
163 }
164
165 Self::compute_statistics(&daily_groups, config, height, width, n_bands)
166 }
167
168 fn aggregate_weekly(
170 ts: &TimeSeriesRaster,
171 config: &AggregationConfig,
172 ) -> Result<AggregationResult> {
173 let (height, width, n_bands) = ts
174 .expected_shape()
175 .ok_or_else(|| TemporalError::insufficient_data("No shape information"))?;
176
177 let mut weekly_groups: TemporalGroups = HashMap::new();
179
180 for (_, entry) in ts.iter() {
181 let year = entry.metadata.acquisition_date.year();
182 let week = entry.metadata.acquisition_date.iso_week().week();
183 let week_key = format!("{}-W{:02}", year, week);
184
185 if let Some(data) = entry.data.as_ref() {
186 weekly_groups
187 .entry(week_key)
188 .or_default()
189 .push((entry.metadata.timestamp, data));
190 }
191 }
192
193 Self::compute_statistics(&weekly_groups, config, height, width, n_bands)
194 }
195
196 fn aggregate_monthly(
198 ts: &TimeSeriesRaster,
199 config: &AggregationConfig,
200 ) -> Result<AggregationResult> {
201 let (height, width, n_bands) = ts
202 .expected_shape()
203 .ok_or_else(|| TemporalError::insufficient_data("No shape information"))?;
204
205 let mut monthly_groups: TemporalGroups = HashMap::new();
207
208 for (_, entry) in ts.iter() {
209 let year = entry.metadata.acquisition_date.year();
210 let month = entry.metadata.acquisition_date.month();
211 let month_key = format!("{}-{:02}", year, month);
212
213 if let Some(data) = entry.data.as_ref() {
214 monthly_groups
215 .entry(month_key)
216 .or_default()
217 .push((entry.metadata.timestamp, data));
218 }
219 }
220
221 Self::compute_statistics(&monthly_groups, config, height, width, n_bands)
222 }
223
224 fn aggregate_yearly(
226 ts: &TimeSeriesRaster,
227 config: &AggregationConfig,
228 ) -> Result<AggregationResult> {
229 let (height, width, n_bands) = ts
230 .expected_shape()
231 .ok_or_else(|| TemporalError::insufficient_data("No shape information"))?;
232
233 let mut yearly_groups: TemporalGroups = HashMap::new();
235
236 for (_, entry) in ts.iter() {
237 let year = entry.metadata.acquisition_date.year();
238 let year_key = format!("{}", year);
239
240 if let Some(data) = entry.data.as_ref() {
241 yearly_groups
242 .entry(year_key)
243 .or_default()
244 .push((entry.metadata.timestamp, data));
245 }
246 }
247
248 Self::compute_statistics(&yearly_groups, config, height, width, n_bands)
249 }
250
251 fn aggregate_custom_days(
253 ts: &TimeSeriesRaster,
254 config: &AggregationConfig,
255 days: i64,
256 ) -> Result<AggregationResult> {
257 if days <= 0 {
258 return Err(TemporalError::invalid_parameter("days", "must be positive"));
259 }
260
261 let (height, width, n_bands) = ts
262 .expected_shape()
263 .ok_or_else(|| TemporalError::insufficient_data("No shape information"))?;
264
265 let (start_time, end_time) = ts
266 .time_range()
267 .ok_or_else(|| TemporalError::insufficient_data("Empty time series"))?;
268
269 let mut window_groups: TemporalGroups = HashMap::new();
271 let mut current = start_time;
272
273 while current < end_time {
274 let next = current + Duration::days(days);
275 let window_key = format!("{}", current.format("%Y-%m-%d"));
276
277 let entries = ts.query_range(¤t, &next);
279 for entry in entries {
280 if let Some(data) = entry.data.as_ref() {
281 window_groups
282 .entry(window_key.clone())
283 .or_default()
284 .push((entry.metadata.timestamp, data));
285 }
286 }
287
288 current = next;
289 }
290
291 Self::compute_statistics(&window_groups, config, height, width, n_bands)
292 }
293
294 fn aggregate_rolling(
296 ts: &TimeSeriesRaster,
297 config: &AggregationConfig,
298 window_size: usize,
299 ) -> Result<AggregationResult> {
300 if window_size == 0 {
301 return Err(TemporalError::invalid_parameter(
302 "window_size",
303 "must be greater than 0",
304 ));
305 }
306
307 if window_size > ts.len() {
308 return Err(TemporalError::invalid_parameter(
309 "window_size",
310 "exceeds time series length",
311 ));
312 }
313
314 let (height, width, n_bands) = ts
315 .expected_shape()
316 .ok_or_else(|| TemporalError::insufficient_data("No shape information"))?;
317
318 let entries: Vec<_> = ts.iter().collect();
319 let mut result = AggregationResult::new();
320
321 for stat in &config.statistics {
323 result.add(format!("{:?}", stat), TimeSeriesRaster::new());
324 }
325
326 for i in 0..=(entries.len().saturating_sub(window_size)) {
328 let window_entries: Vec<_> = entries[i..i + window_size]
329 .iter()
330 .filter_map(|(_, e)| e.data.as_ref().map(|d| (e.metadata.timestamp, d)))
331 .collect();
332
333 if window_entries.len() < config.min_observations {
334 continue;
335 }
336
337 let mid_timestamp = window_entries[window_entries.len() / 2].0;
339
340 for stat in &config.statistics {
342 let aggregated = Self::compute_statistic(
343 &window_entries,
344 *stat,
345 config,
346 height,
347 width,
348 n_bands,
349 )?;
350
351 let metadata = TemporalMetadata::new(mid_timestamp, mid_timestamp.date_naive());
352 let stat_key = format!("{:?}", stat);
353 if let Some(ts) = result.time_series.get_mut(&stat_key) {
354 ts.add_raster(metadata, aggregated)?;
355 }
356 }
357 }
358
359 info!(
360 "Completed rolling aggregation with window size {}",
361 window_size
362 );
363 Ok(result)
364 }
365
366 fn compute_statistics(
368 groups: &TemporalGroups,
369 config: &AggregationConfig,
370 height: usize,
371 width: usize,
372 n_bands: usize,
373 ) -> Result<AggregationResult> {
374 let mut result = AggregationResult::new();
375
376 for stat in &config.statistics {
378 result.add(format!("{:?}", stat), TimeSeriesRaster::new());
379 }
380
381 let mut sorted_keys: Vec<_> = groups.keys().collect();
383 sorted_keys.sort();
384
385 for key in sorted_keys {
386 let group = &groups[key];
387
388 if group.len() < config.min_observations {
389 debug!(
390 "Skipping group {} with {} observations (min: {})",
391 key,
392 group.len(),
393 config.min_observations
394 );
395 continue;
396 }
397
398 let timestamp = group[0].0;
400
401 for stat in &config.statistics {
403 let aggregated =
404 Self::compute_statistic(group, *stat, config, height, width, n_bands)?;
405
406 let metadata = TemporalMetadata::new(timestamp, timestamp.date_naive());
407 let stat_key = format!("{:?}", stat);
408 if let Some(ts) = result.time_series.get_mut(&stat_key) {
409 ts.add_raster(metadata, aggregated)?;
410 }
411 }
412 }
413
414 info!(
415 "Aggregated {} groups with {} statistics",
416 groups.len(),
417 config.statistics.len()
418 );
419 Ok(result)
420 }
421
    /// Compute one aggregation statistic over `entries`, yielding a single
    /// `(height, width, n_bands)` raster.
    ///
    /// Values failing [`Self::is_valid`] are skipped per cell. Cells with no
    /// valid observation receive `config.nodata` (NaN if unset), except for
    /// `Sum`/`Count` (which stay 0) and `First`/`Last` (which stay all-zero
    /// only when `entries` is empty).
    fn compute_statistic(
        entries: &[(DateTime<Utc>, &Array3<f64>)],
        stat: AggregationStatistic,
        config: &AggregationConfig,
        height: usize,
        width: usize,
        n_bands: usize,
    ) -> Result<Array3<f64>> {
        let mut result: Array3<f64> = Array3::zeros((height, width, n_bands));

        match stat {
            AggregationStatistic::Mean => {
                // Per-cell running sum and count of valid observations.
                let mut sum: Array3<f64> = Array3::zeros((height, width, n_bands));
                let mut count: Array3<f64> = Array3::zeros((height, width, n_bands));

                for (_, data) in entries {
                    for i in 0..height {
                        for j in 0..width {
                            for k in 0..n_bands {
                                let val = data[[i, j, k]];
                                if Self::is_valid(val, config.nodata) {
                                    sum[[i, j, k]] += val;
                                    count[[i, j, k]] += 1.0;
                                }
                            }
                        }
                    }
                }

                for i in 0..height {
                    for j in 0..width {
                        for k in 0..n_bands {
                            // Divide sum by count; cells with no valid data get nodata.
                            result[[i, j, k]] = if count[[i, j, k]] > 0.0 {
                                sum[[i, j, k]] / count[[i, j, k]]
                            } else {
                                config.nodata.unwrap_or(f64::NAN)
                            };
                        }
                    }
                }
            }
            AggregationStatistic::Min => {
                // Start at +inf so any valid value replaces it.
                result.fill(f64::INFINITY);
                for (_, data) in entries {
                    for i in 0..height {
                        for j in 0..width {
                            for k in 0..n_bands {
                                let val = data[[i, j, k]];
                                if Self::is_valid(val, config.nodata) && val < result[[i, j, k]] {
                                    result[[i, j, k]] = val;
                                }
                            }
                        }
                    }
                }
                // Cells never updated are still infinite -> map to nodata.
                // NOTE(review): a genuinely infinite minimum would also be
                // replaced here — confirm infinities cannot occur in valid data.
                for i in 0..height {
                    for j in 0..width {
                        for k in 0..n_bands {
                            if result[[i, j, k]].is_infinite() {
                                result[[i, j, k]] = config.nodata.unwrap_or(f64::NAN);
                            }
                        }
                    }
                }
            }
            AggregationStatistic::Max => {
                // Start at -inf so any valid value replaces it.
                result.fill(f64::NEG_INFINITY);
                for (_, data) in entries {
                    for i in 0..height {
                        for j in 0..width {
                            for k in 0..n_bands {
                                let val = data[[i, j, k]];
                                if Self::is_valid(val, config.nodata) && val > result[[i, j, k]] {
                                    result[[i, j, k]] = val;
                                }
                            }
                        }
                    }
                }
                // Cells never updated are still infinite -> map to nodata.
                // NOTE(review): same caveat as Min regarding infinite extrema.
                for i in 0..height {
                    for j in 0..width {
                        for k in 0..n_bands {
                            if result[[i, j, k]].is_infinite() {
                                result[[i, j, k]] = config.nodata.unwrap_or(f64::NAN);
                            }
                        }
                    }
                }
            }
            AggregationStatistic::Sum => {
                // Accumulate valid values; cells with no valid data remain 0.
                for (_, data) in entries {
                    for i in 0..height {
                        for j in 0..width {
                            for k in 0..n_bands {
                                let val = data[[i, j, k]];
                                if Self::is_valid(val, config.nodata) {
                                    result[[i, j, k]] += val;
                                }
                            }
                        }
                    }
                }
            }
            AggregationStatistic::Count => {
                // Number of valid observations per cell (stored as f64).
                for (_, data) in entries {
                    for i in 0..height {
                        for j in 0..width {
                            for k in 0..n_bands {
                                let val = data[[i, j, k]];
                                if Self::is_valid(val, config.nodata) {
                                    result[[i, j, k]] += 1.0;
                                }
                            }
                        }
                    }
                }
            }
            AggregationStatistic::First => {
                // Whole first raster, cloned as-is (no nodata filtering).
                if let Some((_, first_data)) = entries.first() {
                    result = (*first_data).clone();
                }
            }
            AggregationStatistic::Last => {
                // Whole last raster, cloned as-is (no nodata filtering).
                if let Some((_, last_data)) = entries.last() {
                    result = (*last_data).clone();
                }
            }
            AggregationStatistic::StdDev => {
                // Two-pass algorithm: first pass accumulates mean and count.
                let mut mean: Array3<f64> = Array3::zeros((height, width, n_bands));
                let mut count: Array3<f64> = Array3::zeros((height, width, n_bands));

                for (_, data) in entries {
                    for i in 0..height {
                        for j in 0..width {
                            for k in 0..n_bands {
                                let val = data[[i, j, k]];
                                if Self::is_valid(val, config.nodata) {
                                    mean[[i, j, k]] += val;
                                    count[[i, j, k]] += 1.0;
                                }
                            }
                        }
                    }
                }

                for i in 0..height {
                    for j in 0..width {
                        for k in 0..n_bands {
                            if count[[i, j, k]] > 0.0 {
                                mean[[i, j, k]] /= count[[i, j, k]];
                            }
                        }
                    }
                }

                // Second pass accumulates squared deviations from the mean.
                let mut variance: Array3<f64> = Array3::zeros((height, width, n_bands));

                for (_, data) in entries {
                    for i in 0..height {
                        for j in 0..width {
                            for k in 0..n_bands {
                                let val = data[[i, j, k]];
                                if Self::is_valid(val, config.nodata) {
                                    let diff = val - mean[[i, j, k]];
                                    variance[[i, j, k]] += diff * diff;
                                }
                            }
                        }
                    }
                }

                for i in 0..height {
                    for j in 0..width {
                        for k in 0..n_bands {
                            // Population std dev (divides by N, not N-1);
                            // cells with fewer than 2 valid values get nodata.
                            result[[i, j, k]] = if count[[i, j, k]] > 1.0 {
                                (variance[[i, j, k]] / count[[i, j, k]]).sqrt()
                            } else {
                                config.nodata.unwrap_or(f64::NAN)
                            };
                        }
                    }
                }
            }
            AggregationStatistic::Median => {
                // Collect the valid values of every cell into its own vector.
                let mut pixel_values: Vec<Vec<Vec<Vec<f64>>>> =
                    vec![vec![vec![Vec::new(); n_bands]; width]; height];

                for (_, data) in entries {
                    for i in 0..height {
                        for j in 0..width {
                            for k in 0..n_bands {
                                let val = data[[i, j, k]];
                                if Self::is_valid(val, config.nodata) {
                                    pixel_values[i][j][k].push(val);
                                }
                            }
                        }
                    }
                }

                for i in 0..height {
                    for j in 0..width {
                        for k in 0..n_bands {
                            let values = &mut pixel_values[i][j][k];
                            if values.is_empty() {
                                result[[i, j, k]] = config.nodata.unwrap_or(f64::NAN);
                            } else {
                                // partial_cmp falls back to Equal so sorting never
                                // panics (relevant only if NaNs slip through with
                                // `nodata: None`).
                                values.sort_by(|a, b| {
                                    a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                                });
                                let mid = values.len() / 2;
                                // Even count: average the two middle values.
                                result[[i, j, k]] = if values.len() % 2 == 0 {
                                    (values[mid - 1] + values[mid]) / 2.0
                                } else {
                                    values[mid]
                                };
                            }
                        }
                    }
                }
            }
        }

        Ok(result)
    }
654
655 fn is_valid(val: f64, nodata: Option<f64>) -> bool {
657 if let Some(nd) = nodata {
658 if nd.is_nan() {
659 !val.is_nan()
660 } else {
661 (val - nd).abs() > 1e-10
662 }
663 } else {
664 true
665 }
666 }
667}
668
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timeseries::TemporalMetadata;
    use chrono::NaiveDate;

    /// Build a 30-day daily series (January 2022) of 5x5x2 rasters where
    /// every cell of day `i` holds the value `i`.
    fn create_test_timeseries() -> TimeSeriesRaster {
        let mut ts = TimeSeriesRaster::new();

        for i in 0..30 {
            // 1640995200 == 2022-01-01T00:00:00Z; advance one day (86400 s) per step.
            let dt = DateTime::from_timestamp(1640995200 + i * 86400, 0).expect("valid");
            let date = NaiveDate::from_ymd_opt(2022, 1, 1 + i as u32).expect("valid");
            let metadata = TemporalMetadata::new(dt, date);
            let data = Array3::from_elem((5, 5, 2), i as f64);
            ts.add_raster(metadata, data).expect("should add");
        }

        ts
    }

    // Statistic series are looked up by the enum variant's Debug name.
    #[test]
    fn test_daily_aggregation() {
        let ts = create_test_timeseries();
        let config = AggregationConfig {
            window: TemporalWindow::Daily,
            statistics: vec![AggregationStatistic::Mean],
            ..Default::default()
        };

        let result = TemporalAggregator::aggregate(&ts, &config).expect("should aggregate");
        assert!(result.get("Mean").is_some());
    }

    #[test]
    fn test_weekly_aggregation() {
        let ts = create_test_timeseries();
        let config = AggregationConfig {
            window: TemporalWindow::Weekly,
            statistics: vec![AggregationStatistic::Mean, AggregationStatistic::Max],
            ..Default::default()
        };

        let result = TemporalAggregator::aggregate(&ts, &config).expect("should aggregate");
        assert!(result.get("Mean").is_some());
        assert!(result.get("Max").is_some());
    }

    #[test]
    fn test_monthly_aggregation() {
        let ts = create_test_timeseries();
        let config = AggregationConfig {
            window: TemporalWindow::Monthly,
            statistics: vec![AggregationStatistic::Mean],
            ..Default::default()
        };

        let result = TemporalAggregator::aggregate(&ts, &config).expect("should aggregate");
        let mean_ts = result.get("Mean").expect("should have mean");
        assert!(!mean_ts.is_empty());
    }

    // 30 entries with a 7-wide window yield 24 rolling positions.
    #[test]
    fn test_rolling_aggregation() {
        let ts = create_test_timeseries();
        let config = AggregationConfig {
            window: TemporalWindow::Rolling(7),
            statistics: vec![AggregationStatistic::Mean],
            min_observations: 5,
            ..Default::default()
        };

        let result = TemporalAggregator::aggregate(&ts, &config).expect("should aggregate");
        let mean_ts = result.get("Mean").expect("should have mean");
        assert!(!mean_ts.is_empty());
    }

    #[test]
    fn test_multiple_statistics() {
        let ts = create_test_timeseries();
        let config = AggregationConfig {
            window: TemporalWindow::Weekly,
            statistics: vec![
                AggregationStatistic::Mean,
                AggregationStatistic::Min,
                AggregationStatistic::Max,
                AggregationStatistic::StdDev,
            ],
            ..Default::default()
        };

        let result = TemporalAggregator::aggregate(&ts, &config).expect("should aggregate");
        assert!(result.get("Mean").is_some());
        assert!(result.get("Min").is_some());
        assert!(result.get("Max").is_some());
        assert!(result.get("StdDev").is_some());
    }
}