Skip to main content

oxigdal_analytics/zonal/
stats.rs

1//! Advanced Zonal Statistics
2//!
3//! Calculate statistics for regions defined by zone masks.
4
5use crate::error::{AnalyticsError, Result};
6use scirs2_core::ndarray::{ArrayView2, ArrayView3};
7use std::collections::HashMap;
8
/// A summary statistic that can be requested for each zone.
///
/// Every statistic is reported as an `f64`, including [`ZonalStatistic::Count`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ZonalStatistic {
    /// Arithmetic mean of the zone's values.
    Mean,
    /// Middle value of the sorted zone values; for an even number of values,
    /// the average of the two middle values.
    Median,
    /// Smallest value in the zone.
    Min,
    /// Largest value in the zone.
    Max,
    /// Sum of the zone's values.
    Sum,
    /// Number of pixels that contributed to the zone.
    Count,
    /// Population standard deviation (square root of [`ZonalStatistic::Variance`]).
    StdDev,
    /// Population variance (squared deviations from the mean divided by the count).
    Variance,
    /// Coefficient of variation as a percentage (`std_dev / mean * 100`);
    /// NaN when the mean is effectively zero.
    CoeffVar,
    /// Percentile given as a whole number in `0..=100`, linearly interpolated
    /// between neighbouring sorted values. Values above 100 are rejected when
    /// the statistic is calculated.
    Percentile(u8),
}
33
34/// Zonal statistics result
35#[derive(Debug, Clone)]
36pub struct ZonalResult {
37    /// Statistics per zone
38    pub zones: HashMap<i32, HashMap<ZonalStatistic, f64>>,
39    /// Zone IDs
40    pub zone_ids: Vec<i32>,
41}
42
43/// Zonal statistics calculator
44pub struct ZonalCalculator {
45    statistics: Vec<ZonalStatistic>,
46    no_data_value: Option<f64>,
47}
48
49impl ZonalCalculator {
50    /// Create a new zonal calculator
51    pub fn new() -> Self {
52        Self {
53            statistics: vec![
54                ZonalStatistic::Mean,
55                ZonalStatistic::Min,
56                ZonalStatistic::Max,
57                ZonalStatistic::Count,
58            ],
59            no_data_value: None,
60        }
61    }
62
63    /// Set statistics to calculate
64    pub fn with_statistics(mut self, stats: Vec<ZonalStatistic>) -> Self {
65        self.statistics = stats;
66        self
67    }
68
69    /// Set no-data value
70    pub fn with_no_data(mut self, value: f64) -> Self {
71        self.no_data_value = Some(value);
72        self
73    }
74
75    /// Calculate zonal statistics
76    ///
77    /// # Arguments
78    /// * `values` - Value raster (height × width)
79    /// * `zones` - Zone raster with integer zone IDs (height × width)
80    ///
81    /// # Errors
82    /// Returns error if dimensions don't match
83    pub fn calculate(
84        &self,
85        values: &ArrayView2<f64>,
86        zones: &ArrayView2<i32>,
87    ) -> Result<ZonalResult> {
88        if values.dim() != zones.dim() {
89            return Err(AnalyticsError::dimension_mismatch(
90                format!("{:?}", values.dim()),
91                format!("{:?}", zones.dim()),
92            ));
93        }
94
95        // Group values by zone
96        let mut zone_values: HashMap<i32, Vec<f64>> = HashMap::new();
97
98        for ((i, j), &zone_id) in zones.indexed_iter() {
99            let value = values[[i, j]];
100
101            // Skip no-data values
102            if let Some(no_data) = self.no_data_value {
103                if (value - no_data).abs() < f64::EPSILON {
104                    continue;
105                }
106            }
107
108            zone_values.entry(zone_id).or_default().push(value);
109        }
110
111        // Calculate statistics for each zone
112        let mut result_zones = HashMap::new();
113        let mut zone_ids: Vec<i32> = zone_values.keys().copied().collect();
114        zone_ids.sort_unstable();
115
116        for (&zone_id, values_in_zone) in &zone_values {
117            let mut stats = HashMap::new();
118
119            for &statistic in &self.statistics {
120                let value = self.calculate_statistic(statistic, values_in_zone)?;
121                stats.insert(statistic, value);
122            }
123
124            result_zones.insert(zone_id, stats);
125        }
126
127        Ok(ZonalResult {
128            zones: result_zones,
129            zone_ids,
130        })
131    }
132
133    /// Calculate multi-band zonal statistics
134    ///
135    /// # Arguments
136    /// * `values` - Multi-band value raster (height × width × bands)
137    /// * `zones` - Zone raster (height × width)
138    ///
139    /// # Errors
140    /// Returns error if dimensions don't match
141    pub fn calculate_multiband(
142        &self,
143        values: &ArrayView3<f64>,
144        zones: &ArrayView2<i32>,
145    ) -> Result<Vec<ZonalResult>> {
146        let (height, width, n_bands) = values.dim();
147
148        if (height, width) != zones.dim() {
149            return Err(AnalyticsError::dimension_mismatch(
150                format!("{}x{}", height, width),
151                format!("{:?}", zones.dim()),
152            ));
153        }
154
155        let mut results = Vec::with_capacity(n_bands);
156
157        for band in 0..n_bands {
158            let band_values = values.slice(s![.., .., band]);
159            let result = self.calculate(&band_values, zones)?;
160            results.push(result);
161        }
162
163        Ok(results)
164    }
165
166    /// Calculate a single statistic
167    fn calculate_statistic(&self, stat: ZonalStatistic, values: &[f64]) -> Result<f64> {
168        if values.is_empty() {
169            return Ok(f64::NAN);
170        }
171
172        match stat {
173            ZonalStatistic::Mean => Ok(values.iter().sum::<f64>() / values.len() as f64),
174            ZonalStatistic::Median => self.calculate_median(values),
175            ZonalStatistic::Min => values
176                .iter()
177                .copied()
178                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
179                .ok_or_else(|| AnalyticsError::zonal_stats_error("Failed to compute min")),
180            ZonalStatistic::Max => values
181                .iter()
182                .copied()
183                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
184                .ok_or_else(|| AnalyticsError::zonal_stats_error("Failed to compute max")),
185            ZonalStatistic::Sum => Ok(values.iter().sum()),
186            ZonalStatistic::Count => Ok(values.len() as f64),
187            ZonalStatistic::StdDev => self.calculate_std_dev(values),
188            ZonalStatistic::Variance => self.calculate_variance(values),
189            ZonalStatistic::CoeffVar => {
190                let mean = values.iter().sum::<f64>() / values.len() as f64;
191                let std_dev = self.calculate_std_dev(values)?;
192                Ok(if mean.abs() > f64::EPSILON {
193                    (std_dev / mean) * 100.0
194                } else {
195                    f64::NAN
196                })
197            }
198            ZonalStatistic::Percentile(p) => self.calculate_percentile(values, p),
199        }
200    }
201
202    fn calculate_median(&self, values: &[f64]) -> Result<f64> {
203        let mut sorted = values.to_vec();
204        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
205
206        let n = sorted.len();
207        if n % 2 == 0 {
208            Ok((sorted[n / 2 - 1] + sorted[n / 2]) / 2.0)
209        } else {
210            Ok(sorted[n / 2])
211        }
212    }
213
214    fn calculate_variance(&self, values: &[f64]) -> Result<f64> {
215        let n = values.len() as f64;
216        let mean = values.iter().sum::<f64>() / n;
217        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
218        Ok(variance)
219    }
220
221    fn calculate_std_dev(&self, values: &[f64]) -> Result<f64> {
222        Ok(self.calculate_variance(values)?.sqrt())
223    }
224
225    fn calculate_percentile(&self, values: &[f64], percentile: u8) -> Result<f64> {
226        if percentile > 100 {
227            return Err(AnalyticsError::invalid_parameter(
228                "percentile",
229                "must be between 0 and 100",
230            ));
231        }
232
233        let mut sorted = values.to_vec();
234        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
235
236        let n = sorted.len();
237        let rank = (percentile as f64 / 100.0) * ((n - 1) as f64);
238        let lower_idx = rank.floor() as usize;
239        let upper_idx = rank.ceil() as usize;
240        let fraction = rank - (lower_idx as f64);
241
242        Ok(sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]))
243    }
244}
245
246impl Default for ZonalCalculator {
247    fn default() -> Self {
248        Self::new()
249    }
250}
251
252/// Weighted zonal statistics calculator
253pub struct WeightedZonalCalculator {
254    calculator: ZonalCalculator,
255}
256
257impl WeightedZonalCalculator {
258    /// Create a new weighted zonal calculator
259    pub fn new() -> Self {
260        Self {
261            calculator: ZonalCalculator::new(),
262        }
263    }
264
265    /// Set statistics to calculate
266    pub fn with_statistics(mut self, stats: Vec<ZonalStatistic>) -> Self {
267        self.calculator = self.calculator.with_statistics(stats);
268        self
269    }
270
271    /// Calculate weighted zonal statistics
272    ///
273    /// # Arguments
274    /// * `values` - Value raster
275    /// * `weights` - Weight raster (same dimensions as values)
276    /// * `zones` - Zone raster
277    ///
278    /// # Errors
279    /// Returns error if dimensions don't match
280    pub fn calculate(
281        &self,
282        values: &ArrayView2<f64>,
283        weights: &ArrayView2<f64>,
284        zones: &ArrayView2<i32>,
285    ) -> Result<ZonalResult> {
286        if values.dim() != weights.dim() || values.dim() != zones.dim() {
287            return Err(AnalyticsError::dimension_mismatch(
288                format!("{:?}", values.dim()),
289                "all inputs must have same dimensions".to_string(),
290            ));
291        }
292
293        // Group weighted values by zone
294        let mut zone_data: HashMap<i32, (Vec<f64>, Vec<f64>)> = HashMap::new();
295
296        for ((i, j), &zone_id) in zones.indexed_iter() {
297            let value = values[[i, j]];
298            let weight = weights[[i, j]];
299
300            if weight > 0.0 {
301                let entry = zone_data
302                    .entry(zone_id)
303                    .or_insert_with(|| (Vec::new(), Vec::new()));
304                entry.0.push(value);
305                entry.1.push(weight);
306            }
307        }
308
309        // Calculate weighted statistics
310        let mut result_zones = HashMap::new();
311        let mut zone_ids: Vec<i32> = zone_data.keys().copied().collect();
312        zone_ids.sort_unstable();
313
314        for (&zone_id, (values_in_zone, weights_in_zone)) in &zone_data {
315            let mut stats = HashMap::new();
316
317            // Weighted mean
318            let weighted_sum: f64 = values_in_zone
319                .iter()
320                .zip(weights_in_zone.iter())
321                .map(|(v, w)| v * w)
322                .sum();
323            let weight_sum: f64 = weights_in_zone.iter().sum();
324
325            if weight_sum > f64::EPSILON {
326                stats.insert(ZonalStatistic::Mean, weighted_sum / weight_sum);
327            }
328
329            // Count (unweighted)
330            stats.insert(ZonalStatistic::Count, values_in_zone.len() as f64);
331
332            // Min/Max (unweighted)
333            if let Some(&min) = values_in_zone
334                .iter()
335                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
336            {
337                stats.insert(ZonalStatistic::Min, min);
338            }
339
340            if let Some(&max) = values_in_zone
341                .iter()
342                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
343            {
344                stats.insert(ZonalStatistic::Max, max);
345            }
346
347            result_zones.insert(zone_id, stats);
348        }
349
350        Ok(ZonalResult {
351            zones: result_zones,
352            zone_ids,
353        })
354    }
355}
356
357impl Default for WeightedZonalCalculator {
358    fn default() -> Self {
359        Self::new()
360    }
361}
362
363// Import ndarray slice macro
364use scirs2_core::ndarray::s;
365
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Two zones on a 3×3 raster with the default statistics.
    #[test]
    fn test_zonal_basic() {
        let values = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let zones = array![[1, 1, 2], [1, 2, 2], [2, 2, 2]];

        let calculator = ZonalCalculator::new();
        let result = calculator
            .calculate(&values.view(), &zones.view())
            .expect("Zonal statistics calculation should succeed");

        assert_eq!(result.zone_ids, vec![1, 2]);
        assert!(result.zones.contains_key(&1));
        assert!(result.zones.contains_key(&2));

        // Zone 1: values 1, 2, 4
        let zone1_stats = &result.zones[&1];
        assert_abs_diff_eq!(
            zone1_stats[&ZonalStatistic::Mean],
            (1.0 + 2.0 + 4.0) / 3.0,
            epsilon = 1e-10
        );
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Min], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Max], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Count], 3.0, epsilon = 1e-10);

        // Zone 2: values 3, 5, 6, 7, 8, 9
        let zone2_stats = &result.zones[&2];
        assert_abs_diff_eq!(
            zone2_stats[&ZonalStatistic::Mean],
            38.0 / 6.0,
            epsilon = 1e-10
        );
        assert_abs_diff_eq!(zone2_stats[&ZonalStatistic::Min], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone2_stats[&ZonalStatistic::Max], 9.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone2_stats[&ZonalStatistic::Count], 6.0, epsilon = 1e-10);
    }

    /// A custom statistics list, including the population standard deviation.
    #[test]
    fn test_zonal_statistics() {
        let values = array![[1.0, 2.0], [3.0, 4.0]];
        let zones = array![[1, 1], [1, 1]];

        let calculator = ZonalCalculator::new().with_statistics(vec![
            ZonalStatistic::Mean,
            ZonalStatistic::Min,
            ZonalStatistic::Max,
            ZonalStatistic::StdDev,
        ]);

        let result = calculator
            .calculate(&values.view(), &zones.view())
            .expect("Zonal statistics with multiple stats should succeed");
        let zone1_stats = &result.zones[&1];

        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Mean], 2.5, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Min], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Max], 4.0, epsilon = 1e-10);
        // Population variance of 1, 2, 3, 4 is 5 / 4 = 1.25.
        assert_abs_diff_eq!(
            zone1_stats[&ZonalStatistic::StdDev],
            1.25_f64.sqrt(),
            epsilon = 1e-10
        );
    }

    /// Uniform weights reduce the weighted mean to the plain mean.
    #[test]
    fn test_weighted_zonal() {
        let values = array![[1.0, 2.0], [3.0, 4.0]];
        let weights = array![[1.0, 1.0], [1.0, 1.0]];
        let zones = array![[1, 1], [1, 1]];

        let calculator = WeightedZonalCalculator::new();
        let result = calculator
            .calculate(&values.view(), &weights.view(), &zones.view())
            .expect("Weighted zonal statistics should succeed");

        let zone1_stats = &result.zones[&1];
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Mean], 2.5, epsilon = 1e-10);
    }

    /// Zero-weight pixels are excluded from every weighted statistic.
    #[test]
    fn test_weighted_zonal_skips_zero_weights() {
        let values = array![[1.0, 2.0], [3.0, 4.0]];
        let weights = array![[1.0, 0.0], [0.0, 3.0]];
        let zones = array![[1, 1], [1, 1]];

        let calculator = WeightedZonalCalculator::new();
        let result = calculator
            .calculate(&values.view(), &weights.view(), &zones.view())
            .expect("Weighted zonal statistics should succeed");

        // Remaining pixels: value 1 (weight 1) and value 4 (weight 3).
        let zone1_stats = &result.zones[&1];
        assert_abs_diff_eq!(
            zone1_stats[&ZonalStatistic::Mean],
            (1.0 + 4.0 * 3.0) / 4.0,
            epsilon = 1e-10
        );
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Count], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Min], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(zone1_stats[&ZonalStatistic::Max], 4.0, epsilon = 1e-10);
    }

    /// Percentiles are linearly interpolated between sorted neighbours.
    #[test]
    fn test_percentile() {
        let values = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let zones = array![[1, 1, 1], [1, 1, 1]];

        let calculator = ZonalCalculator::new().with_statistics(vec![
            ZonalStatistic::Percentile(50), // Median
            ZonalStatistic::Percentile(25),
            ZonalStatistic::Percentile(75),
        ]);

        let result = calculator
            .calculate(&values.view(), &zones.view())
            .expect("Percentile calculation should succeed");
        let zone1_stats = &result.zones[&1];

        assert_abs_diff_eq!(
            zone1_stats[&ZonalStatistic::Percentile(50)],
            3.5,
            epsilon = 1e-10
        );
        // Rank 0.25 * 5 = 1.25, a quarter of the way from 2 to 3.
        assert_abs_diff_eq!(
            zone1_stats[&ZonalStatistic::Percentile(25)],
            2.25,
            epsilon = 1e-10
        );
        // Rank 0.75 * 5 = 3.75, three quarters of the way from 4 to 5.
        assert_abs_diff_eq!(
            zone1_stats[&ZonalStatistic::Percentile(75)],
            4.75,
            epsilon = 1e-10
        );
    }
}
454}