Skip to main content

oxigdal_temporal/
stack.rs

1//! Raster Stack Operations Module
2//!
3//! This module provides operations for stacking multiple rasters together,
4//! including multi-band stacking, temporal stacking, and stack transformations.
5
6use crate::error::{Result, TemporalError};
7#[cfg(feature = "timeseries")]
8use crate::timeseries::TimeSeriesRaster;
9use scirs2_core::ndarray::{Array3, Array4, Axis};
10use serde::{Deserialize, Serialize};
11use tracing::{debug, info};
12
13/// Stack configuration
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct StackConfig {
16    /// Stack along which axis (0=temporal, 1=bands, 2=height, 3=width)
17    pub axis: usize,
18    /// Interpolation method for mismatched dimensions
19    pub interpolation: InterpolationMethod,
20    /// Fill value for missing data
21    pub fill_value: Option<f64>,
22}
23
24impl Default for StackConfig {
25    fn default() -> Self {
26        Self {
27            axis: 0,
28            interpolation: InterpolationMethod::Nearest,
29            fill_value: Some(f64::NAN),
30        }
31    }
32}
33
34/// Interpolation method for resampling
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
36pub enum InterpolationMethod {
37    /// Nearest neighbor
38    Nearest,
39    /// Bilinear interpolation
40    Bilinear,
41    /// Cubic interpolation
42    Cubic,
43}
44
45/// Multi-dimensional raster stack
46///
47/// Represents a stack of rasters organized in a 4D array:
48/// (time, height, width, bands)
49#[derive(Debug, Clone)]
50pub struct RasterStack {
51    /// 4D data array: (time, height, width, bands)
52    data: Array4<f64>,
53    /// Stack metadata
54    metadata: StackMetadata,
55}
56
57/// Stack metadata
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct StackMetadata {
60    /// Number of time steps
61    pub n_time: usize,
62    /// Height (rows)
63    pub height: usize,
64    /// Width (columns)
65    pub width: usize,
66    /// Number of bands
67    pub n_bands: usize,
68    /// Band names
69    pub band_names: Vec<String>,
70    /// NoData value
71    pub nodata: Option<f64>,
72}
73
74impl RasterStack {
75    /// Create new raster stack from 4D array
76    ///
77    /// # Errors
78    /// Returns error if array dimensions are invalid
79    pub fn new(data: Array4<f64>) -> Result<Self> {
80        let shape = data.shape();
81        if shape.len() != 4 {
82            return Err(TemporalError::dimension_mismatch(
83                "4D array",
84                format!("{}D array", shape.len()),
85            ));
86        }
87
88        let metadata = StackMetadata {
89            n_time: shape[0],
90            height: shape[1],
91            width: shape[2],
92            n_bands: shape[3],
93            band_names: (0..shape[3]).map(|i| format!("Band_{}", i + 1)).collect(),
94            nodata: None,
95        };
96
97        Ok(Self { data, metadata })
98    }
99
100    /// Create raster stack from time series
101    ///
102    /// # Errors
103    /// Returns error if time series is empty or data not loaded
104    #[cfg(feature = "timeseries")]
105    pub fn from_timeseries(ts: &TimeSeriesRaster) -> Result<Self> {
106        if ts.is_empty() {
107            return Err(TemporalError::insufficient_data("Time series is empty"));
108        }
109
110        // Get shape from first entry
111        let (height, width, n_bands) = ts
112            .expected_shape()
113            .ok_or_else(|| TemporalError::insufficient_data("No shape information"))?;
114
115        let n_time = ts.len();
116
117        // Initialize 4D array
118        let mut data = Array4::zeros((n_time, height, width, n_bands));
119
120        // Fill data from time series
121        for (t, (_, entry)) in ts.iter().enumerate() {
122            let entry_data = entry.data.as_ref().ok_or_else(|| {
123                TemporalError::invalid_input("Data not loaded in time series entry")
124            })?;
125
126            // Copy data to stack
127            for i in 0..height {
128                for j in 0..width {
129                    for k in 0..n_bands {
130                        data[[t, i, j, k]] = entry_data[[i, j, k]];
131                    }
132                }
133            }
134        }
135
136        let metadata = StackMetadata {
137            n_time,
138            height,
139            width,
140            n_bands,
141            band_names: (0..n_bands).map(|i| format!("Band_{}", i + 1)).collect(),
142            nodata: None,
143        };
144
145        info!(
146            "Created raster stack with shape ({}, {}, {}, {})",
147            n_time, height, width, n_bands
148        );
149
150        Ok(Self { data, metadata })
151    }
152
153    /// Get stack shape (time, height, width, bands)
154    #[must_use]
155    pub fn shape(&self) -> (usize, usize, usize, usize) {
156        (
157            self.metadata.n_time,
158            self.metadata.height,
159            self.metadata.width,
160            self.metadata.n_bands,
161        )
162    }
163
164    /// Get reference to underlying data
165    #[must_use]
166    pub fn data(&self) -> &Array4<f64> {
167        &self.data
168    }
169
170    /// Get mutable reference to underlying data
171    pub fn data_mut(&mut self) -> &mut Array4<f64> {
172        &mut self.data
173    }
174
175    /// Get metadata
176    #[must_use]
177    pub fn metadata(&self) -> &StackMetadata {
178        &self.metadata
179    }
180
181    /// Set band names
182    pub fn set_band_names(&mut self, names: Vec<String>) -> Result<()> {
183        if names.len() != self.metadata.n_bands {
184            return Err(TemporalError::dimension_mismatch(
185                format!("{} bands", self.metadata.n_bands),
186                format!("{} names", names.len()),
187            ));
188        }
189        self.metadata.band_names = names;
190        Ok(())
191    }
192
193    /// Set nodata value
194    pub fn set_nodata(&mut self, nodata: f64) {
195        self.metadata.nodata = Some(nodata);
196    }
197
198    /// Extract temporal slice at specific time index
199    ///
200    /// # Errors
201    /// Returns error if time index is out of bounds
202    pub fn get_time_slice(&self, time_index: usize) -> Result<Array3<f64>> {
203        if time_index >= self.metadata.n_time {
204            return Err(TemporalError::time_index_out_of_bounds(
205                time_index,
206                0,
207                self.metadata.n_time,
208            ));
209        }
210
211        Ok(self.data.index_axis(Axis(0), time_index).to_owned())
212    }
213
214    /// Extract spatial slice for specific band across all time
215    ///
216    /// # Errors
217    /// Returns error if band index is out of bounds
218    pub fn get_band_timeseries(&self, band_index: usize) -> Result<Array3<f64>> {
219        if band_index >= self.metadata.n_bands {
220            return Err(TemporalError::invalid_parameter(
221                "band_index",
222                format!(
223                    "index {} out of bounds (max: {})",
224                    band_index,
225                    self.metadata.n_bands - 1
226                ),
227            ));
228        }
229
230        // Extract (time, height, width) for specific band
231        let mut result = Array3::zeros((
232            self.metadata.n_time,
233            self.metadata.height,
234            self.metadata.width,
235        ));
236
237        for t in 0..self.metadata.n_time {
238            for i in 0..self.metadata.height {
239                for j in 0..self.metadata.width {
240                    result[[t, i, j]] = self.data[[t, i, j, band_index]];
241                }
242            }
243        }
244
245        Ok(result)
246    }
247
248    /// Extract pixel time series at specific location for specific band
249    ///
250    /// # Errors
251    /// Returns error if coordinates are out of bounds
252    pub fn get_pixel_timeseries(&self, row: usize, col: usize, band: usize) -> Result<Vec<f64>> {
253        if row >= self.metadata.height {
254            return Err(TemporalError::invalid_parameter(
255                "row",
256                format!(
257                    "index {} out of bounds (max: {})",
258                    row,
259                    self.metadata.height - 1
260                ),
261            ));
262        }
263        if col >= self.metadata.width {
264            return Err(TemporalError::invalid_parameter(
265                "col",
266                format!(
267                    "index {} out of bounds (max: {})",
268                    col,
269                    self.metadata.width - 1
270                ),
271            ));
272        }
273        if band >= self.metadata.n_bands {
274            return Err(TemporalError::invalid_parameter(
275                "band",
276                format!(
277                    "index {} out of bounds (max: {})",
278                    band,
279                    self.metadata.n_bands - 1
280                ),
281            ));
282        }
283
284        let mut values = Vec::with_capacity(self.metadata.n_time);
285        for t in 0..self.metadata.n_time {
286            values.push(self.data[[t, row, col, band]]);
287        }
288
289        Ok(values)
290    }
291
292    /// Stack multiple bands together
293    ///
294    /// # Errors
295    /// Returns error if shapes don't match
296    pub fn stack_bands(bands: Vec<Array3<f64>>) -> Result<Self> {
297        if bands.is_empty() {
298            return Err(TemporalError::insufficient_data("No bands to stack"));
299        }
300
301        // Check all bands have same shape
302        let first_shape = bands[0].shape();
303        for (i, band) in bands.iter().enumerate().skip(1) {
304            if band.shape() != first_shape {
305                return Err(TemporalError::dimension_mismatch(
306                    format!("{:?}", first_shape),
307                    format!("{:?} (band {})", band.shape(), i),
308                ));
309            }
310        }
311
312        let n_time = first_shape[0];
313        let height = first_shape[1];
314        let width = first_shape[2];
315        let n_bands = bands.len();
316
317        // Create 4D array
318        let mut data = Array4::zeros((n_time, height, width, n_bands));
319
320        for (band_idx, band_data) in bands.iter().enumerate() {
321            for t in 0..n_time {
322                for i in 0..height {
323                    for j in 0..width {
324                        data[[t, i, j, band_idx]] = band_data[[t, i, j]];
325                    }
326                }
327            }
328        }
329
330        let metadata = StackMetadata {
331            n_time,
332            height,
333            width,
334            n_bands,
335            band_names: (0..n_bands).map(|i| format!("Band_{}", i + 1)).collect(),
336            nodata: None,
337        };
338
339        debug!(
340            "Stacked {} bands into shape ({}, {}, {}, {})",
341            n_bands, n_time, height, width, n_bands
342        );
343
344        Ok(Self { data, metadata })
345    }
346
347    /// Concatenate stacks along time axis
348    ///
349    /// # Errors
350    /// Returns error if spatial dimensions don't match
351    pub fn concatenate_time(stacks: Vec<Self>) -> Result<Self> {
352        if stacks.is_empty() {
353            return Err(TemporalError::insufficient_data("No stacks to concatenate"));
354        }
355
356        // Check all stacks have same spatial dimensions and bands
357        let first = &stacks[0];
358        let (_, height, width, n_bands) = first.shape();
359
360        for (i, stack) in stacks.iter().enumerate().skip(1) {
361            let (_, h, w, b) = stack.shape();
362            if h != height || w != width || b != n_bands {
363                return Err(TemporalError::dimension_mismatch(
364                    format!("(?, {}, {}, {})", height, width, n_bands),
365                    format!("(?, {}, {}, {}) at stack {}", h, w, b, i),
366                ));
367            }
368        }
369
370        // Calculate total time steps
371        let total_time: usize = stacks.iter().map(|s| s.metadata.n_time).sum();
372
373        // Create concatenated array
374        let mut data = Array4::zeros((total_time, height, width, n_bands));
375        let mut current_time = 0;
376
377        for stack in &stacks {
378            let stack_time = stack.metadata.n_time;
379            for t in 0..stack_time {
380                for i in 0..height {
381                    for j in 0..width {
382                        for k in 0..n_bands {
383                            data[[current_time + t, i, j, k]] = stack.data[[t, i, j, k]];
384                        }
385                    }
386                }
387            }
388            current_time += stack_time;
389        }
390
391        let metadata = StackMetadata {
392            n_time: total_time,
393            height,
394            width,
395            n_bands,
396            band_names: first.metadata.band_names.clone(),
397            nodata: first.metadata.nodata,
398        };
399
400        info!(
401            "Concatenated {} stacks into shape ({}, {}, {}, {})",
402            stacks.len(),
403            total_time,
404            height,
405            width,
406            n_bands
407        );
408
409        Ok(Self { data, metadata })
410    }
411
412    /// Subset stack by time range
413    ///
414    /// # Errors
415    /// Returns error if indices are out of bounds
416    pub fn subset_time(&self, start: usize, end: usize) -> Result<Self> {
417        if start >= end {
418            return Err(TemporalError::invalid_time_range(
419                start.to_string(),
420                end.to_string(),
421            ));
422        }
423        if end > self.metadata.n_time {
424            return Err(TemporalError::time_index_out_of_bounds(
425                end,
426                0,
427                self.metadata.n_time,
428            ));
429        }
430
431        let n_time = end - start;
432        let mut data = Array4::zeros((
433            n_time,
434            self.metadata.height,
435            self.metadata.width,
436            self.metadata.n_bands,
437        ));
438
439        for (t_out, t_in) in (start..end).enumerate() {
440            for i in 0..self.metadata.height {
441                for j in 0..self.metadata.width {
442                    for k in 0..self.metadata.n_bands {
443                        data[[t_out, i, j, k]] = self.data[[t_in, i, j, k]];
444                    }
445                }
446            }
447        }
448
449        let metadata = StackMetadata {
450            n_time,
451            height: self.metadata.height,
452            width: self.metadata.width,
453            n_bands: self.metadata.n_bands,
454            band_names: self.metadata.band_names.clone(),
455            nodata: self.metadata.nodata,
456        };
457
458        Ok(Self { data, metadata })
459    }
460
461    /// Subset stack by band indices
462    ///
463    /// # Errors
464    /// Returns error if any band index is out of bounds
465    pub fn subset_bands(&self, band_indices: &[usize]) -> Result<Self> {
466        if band_indices.is_empty() {
467            return Err(TemporalError::insufficient_data("No bands selected"));
468        }
469
470        // Validate all indices
471        for &idx in band_indices {
472            if idx >= self.metadata.n_bands {
473                return Err(TemporalError::invalid_parameter(
474                    "band_index",
475                    format!(
476                        "index {} out of bounds (max: {})",
477                        idx,
478                        self.metadata.n_bands - 1
479                    ),
480                ));
481            }
482        }
483
484        let n_bands = band_indices.len();
485        let mut data = Array4::zeros((
486            self.metadata.n_time,
487            self.metadata.height,
488            self.metadata.width,
489            n_bands,
490        ));
491
492        for t in 0..self.metadata.n_time {
493            for i in 0..self.metadata.height {
494                for j in 0..self.metadata.width {
495                    for (k_out, &k_in) in band_indices.iter().enumerate() {
496                        data[[t, i, j, k_out]] = self.data[[t, i, j, k_in]];
497                    }
498                }
499            }
500        }
501
502        let band_names = band_indices
503            .iter()
504            .map(|&i| self.metadata.band_names[i].clone())
505            .collect();
506
507        let metadata = StackMetadata {
508            n_time: self.metadata.n_time,
509            height: self.metadata.height,
510            width: self.metadata.width,
511            n_bands,
512            band_names,
513            nodata: self.metadata.nodata,
514        };
515
516        Ok(Self { data, metadata })
517    }
518
519    /// Apply function to each pixel time series
520    ///
521    /// # Errors
522    /// Returns error if function fails
523    pub fn apply_temporal<F>(&self, func: F) -> Result<Array3<f64>>
524    where
525        F: Fn(&[f64]) -> f64,
526    {
527        let mut result = Array3::zeros((
528            self.metadata.height,
529            self.metadata.width,
530            self.metadata.n_bands,
531        ));
532
533        for i in 0..self.metadata.height {
534            for j in 0..self.metadata.width {
535                for k in 0..self.metadata.n_bands {
536                    let timeseries: Vec<f64> = (0..self.metadata.n_time)
537                        .map(|t| self.data[[t, i, j, k]])
538                        .collect();
539                    result[[i, j, k]] = func(&timeseries);
540                }
541            }
542        }
543
544        Ok(result)
545    }
546
547    /// Calculate mean across time dimension
548    ///
549    /// # Errors
550    /// Returns error if calculation fails
551    pub fn mean_temporal(&self) -> Result<Array3<f64>> {
552        self.apply_temporal(|values| values.iter().sum::<f64>() / values.len() as f64)
553    }
554
555    /// Calculate median across time dimension
556    ///
557    /// # Errors
558    /// Returns error if calculation fails
559    pub fn median_temporal(&self) -> Result<Array3<f64>> {
560        self.apply_temporal(|values| {
561            let mut sorted = values.to_vec();
562            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
563            let mid = sorted.len() / 2;
564            if sorted.len() % 2 == 0 {
565                (sorted[mid - 1] + sorted[mid]) / 2.0
566            } else {
567                sorted[mid]
568            }
569        })
570    }
571
572    /// Calculate standard deviation across time dimension
573    ///
574    /// # Errors
575    /// Returns error if calculation fails
576    pub fn std_temporal(&self) -> Result<Array3<f64>> {
577        self.apply_temporal(|values| {
578            let mean = values.iter().sum::<f64>() / values.len() as f64;
579            let variance =
580                values.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
581            variance.sqrt()
582        })
583    }
584
585    /// Calculate minimum across time dimension
586    ///
587    /// # Errors
588    /// Returns error if calculation fails
589    pub fn min_temporal(&self) -> Result<Array3<f64>> {
590        self.apply_temporal(|values| {
591            values
592                .iter()
593                .copied()
594                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
595                .unwrap_or(f64::NAN)
596        })
597    }
598
599    /// Calculate maximum across time dimension
600    ///
601    /// # Errors
602    /// Returns error if calculation fails
603    pub fn max_temporal(&self) -> Result<Array3<f64>> {
604        self.apply_temporal(|values| {
605            values
606                .iter()
607                .copied()
608                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
609                .unwrap_or(f64::NAN)
610        })
611    }
612}
613
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds a `(series.len(), 2, 2, 1)` stack that is zero everywhere except
    /// at pixel (0, 0) of band 0, whose values over time follow `series`.
    fn single_pixel_stack(series: &[f64]) -> RasterStack {
        let mut cube = Array4::zeros((series.len(), 2, 2, 1));
        for (t, &value) in series.iter().enumerate() {
            cube[[t, 0, 0, 0]] = value;
        }
        RasterStack::new(cube).expect("should create")
    }

    #[test]
    fn test_raster_stack_creation() {
        let stack =
            RasterStack::new(Array4::zeros((10, 100, 100, 3))).expect("should create stack");
        assert_eq!(stack.shape(), (10, 100, 100, 3));
    }

    #[test]
    fn test_get_time_slice() {
        // Mark a single cell so the extracted slice can be recognised.
        let mut cube = Array4::zeros((10, 5, 5, 2));
        cube[[3, 2, 2, 0]] = 42.0;
        let stack = RasterStack::new(cube).expect("should create stack");

        let slice = stack.get_time_slice(3).expect("should get slice");

        assert_eq!(slice.shape(), &[5, 5, 2]);
        assert_abs_diff_eq!(slice[[2, 2, 0]], 42.0);
    }

    #[test]
    fn test_get_pixel_timeseries() {
        // Pixel (2, 3) of band 1 stores its own time index.
        let mut cube = Array4::zeros((10, 5, 5, 2));
        for t in 0..10 {
            cube[[t, 2, 3, 1]] = t as f64;
        }
        let stack = RasterStack::new(cube).expect("should create stack");

        let series = stack
            .get_pixel_timeseries(2, 3, 1)
            .expect("should get timeseries");

        assert_eq!(series.len(), 10);
        for (t, &value) in series.iter().enumerate() {
            assert_abs_diff_eq!(value, t as f64);
        }
    }

    #[test]
    fn test_stack_bands() {
        // Three constant bands holding 1.0, 2.0 and 3.0 respectively.
        let bands: Vec<Array3<f64>> = [1.0, 2.0, 3.0]
            .iter()
            .map(|&fill| Array3::from_elem((5, 10, 10), fill))
            .collect();

        let stack = RasterStack::stack_bands(bands).expect("should stack");

        assert_eq!(stack.shape(), (5, 10, 10, 3));
        assert_abs_diff_eq!(stack.data()[[0, 0, 0, 0]], 1.0);
        assert_abs_diff_eq!(stack.data()[[0, 0, 0, 1]], 2.0);
        assert_abs_diff_eq!(stack.data()[[0, 0, 0, 2]], 3.0);
    }

    #[test]
    fn test_concatenate_time() {
        let head =
            RasterStack::new(Array4::from_elem((5, 10, 10, 2), 1.0)).expect("should create");
        let tail =
            RasterStack::new(Array4::from_elem((3, 10, 10, 2), 2.0)).expect("should create");

        let joined = RasterStack::concatenate_time(vec![head, tail]).expect("should concatenate");

        assert_eq!(joined.shape(), (8, 10, 10, 2));
    }

    #[test]
    fn test_subset_time() {
        let stack = RasterStack::new(Array4::zeros((10, 5, 5, 2))).expect("should create");

        let window = stack.subset_time(2, 7).expect("should subset");

        assert_eq!(window.shape(), (5, 5, 5, 2));
    }

    #[test]
    fn test_subset_bands() {
        let stack = RasterStack::new(Array4::zeros((10, 5, 5, 5))).expect("should create");

        let selection = stack.subset_bands(&[0, 2, 4]).expect("should subset");

        assert_eq!(selection.shape(), (10, 5, 5, 3));
    }

    #[test]
    fn test_mean_temporal() {
        let stack = single_pixel_stack(&[1.0, 2.0, 3.0]);

        let mean = stack.mean_temporal().expect("should calculate mean");

        assert_abs_diff_eq!(mean[[0, 0, 0]], 2.0);
    }

    #[test]
    fn test_median_temporal() {
        let stack = single_pixel_stack(&[1.0, 2.0, 3.0, 4.0, 5.0]);

        let median = stack.median_temporal().expect("should calculate median");

        assert_abs_diff_eq!(median[[0, 0, 0]], 3.0);
    }

    #[test]
    fn test_min_max_temporal() {
        let stack = single_pixel_stack(&[1.0, 5.0, 3.0, 2.0, 4.0]);

        let min = stack.min_temporal().expect("should calculate min");
        let max = stack.max_temporal().expect("should calculate max");

        assert_abs_diff_eq!(min[[0, 0, 0]], 1.0);
        assert_abs_diff_eq!(max[[0, 0, 0]], 5.0);
    }
}
746}