Skip to main content

rusterize/encoding/
arrays.rs

1use crate::{
2    geo::raster::RasterInfo,
3    prelude::{RasterDtype, RasterizeContext},
4    rasterization::pixel_functions::PixelFn,
5};
6use ndarray::Array3;
7use num_traits::Num;
8use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
9
10/// A materialized 3-dimensional array containing the burned geometries and spatial information.
11pub struct DenseArray<N> {
12    raster: Array3<N>,
13    band_names: Vec<String>,
14    raster_info: RasterInfo,
15}
16
17impl<N: Num> DenseArray<N> {
18    pub(crate) fn new(raster: Array3<N>, band_names: Vec<String>, raster_info: RasterInfo) -> Self {
19        Self {
20            raster,
21            band_names,
22            raster_info,
23        }
24    }
25
26    /// Consume self and extract all fields of the DenseArray.
27    pub fn into_parts(self) -> (Array3<N>, Vec<String>, RasterInfo) {
28        (self.raster, self.band_names, self.raster_info)
29    }
30
31    /// Sorted band names for the array. Defaults to "band_1" for a single band.
32    pub fn band_names(&self) -> &[String] {
33        &self.band_names
34    }
35
36    /// Spatial information associated with the array.
37    pub fn raster_info(&self) -> &RasterInfo {
38        &self.raster_info
39    }
40}
41
/// Column-wise storage of `(row, col, value)` triplets.
///
/// The triplets of all bands are kept back to back in one contiguous block;
/// this is the payload stored inside a [`SparseArray`].
struct Triplets<N> {
    /// Row index of each triplet.
    rows: Vec<u64>,
    /// Column index of each triplet.
    cols: Vec<u64>,
    /// Value of each triplet.
    data: Vec<N>,
}
49
50impl<N: Num> Triplets<N> {
51    fn new(rows: Vec<u64>, cols: Vec<u64>, data: Vec<N>) -> Self {
52        Self { rows, cols, data }
53    }
54}
55
56/// A sparse array in COOordinate format storing the band/row/col value triplets.
57/// of all burned [`geo::Geometry`].
58pub struct SparseArray<N> {
59    band_names: Vec<String>,
60    triplets: Triplets<N>,
61    lengths: Vec<usize>,
62    raster_info: RasterInfo,
63    pxfn: PixelFn<N>,
64    background: N,
65}
66
67impl<N> SparseArray<N>
68where
69    N: RasterDtype,
70{
71    pub(crate) fn new(
72        band_names: Vec<String>,
73        rows: Vec<u64>,
74        cols: Vec<u64>,
75        data: Vec<N>,
76        lengths: Vec<usize>,
77        ctx: RasterizeContext<N>,
78    ) -> Self {
79        let pxfn = ctx.pixel_fn();
80        let background = ctx.background;
81
82        Self {
83            band_names,
84            triplets: Triplets::new(rows, cols, data),
85            lengths,
86            raster_info: ctx.raster_info,
87            pxfn,
88            background,
89        }
90    }
91
92    /// Get the band names associated with this array.
93    pub fn band_names(&self) -> &[String] {
94        &self.band_names
95    }
96
97    /// Materialize a [`ndarray::Array3`] from this. Drops spatial information.
98    pub fn build_array(&self) -> Array3<N> {
99        let mut raster = self.raster_info.build_raster(self.band_names.len(), self.background);
100
101        let rows = self.triplets.rows.as_slice();
102        let cols = self.triplets.cols.as_slice();
103        let data = self.triplets.data.as_slice();
104
105        // per-band start offset into the contiguous triplet arrays
106        let offsets = self
107            .lengths
108            .iter()
109            .scan(0, |state, &n| {
110                let start = *state;
111                *state += n;
112                Some(start)
113            })
114            .collect::<Vec<usize>>();
115
116        raster
117            .outer_iter_mut()
118            .into_par_iter()
119            .zip(self.lengths.par_iter())
120            .zip(offsets.par_iter())
121            .for_each(|((mut band, n), &off)| {
122                let end = off + *n;
123                let band_rows = &rows[off..end];
124                let band_cols = &cols[off..end];
125                let band_data = &data[off..end];
126
127                for ((band_row, band_col), band_value) in band_rows.iter().zip(band_cols).zip(band_data) {
128                    (self.pxfn)(
129                        &mut band,
130                        *band_row as usize,
131                        *band_col as usize,
132                        *band_value,
133                        self.background,
134                    );
135                }
136            });
137        raster
138    }
139
140    pub fn extent(&self) -> (f64, f64, f64, f64) {
141        (
142            self.raster_info.xmin,
143            self.raster_info.ymin,
144            self.raster_info.xmax,
145            self.raster_info.ymax,
146        )
147    }
148
149    pub fn shape(&self) -> (usize, usize) {
150        (self.raster_info.nrows, self.raster_info.ncols)
151    }
152
153    pub fn resolution(&self) -> (f64, f64) {
154        (self.raster_info.xres, self.raster_info.yres)
155    }
156
157    /// Get spatial information associated with this array.
158    pub fn raster_info(&self) -> &RasterInfo {
159        &self.raster_info
160    }
161
162    pub fn epsg(&self) -> Option<u16> {
163        self.raster_info.epsg
164    }
165}
166
#[cfg(feature = "polars")]
mod feature_gated {
    use super::SparseArray;
    use crate::prelude::PolarsHandler;
    use num_traits::Num;
    use polars::prelude::*;

    impl<N> SparseArray<N>
    where
        N: Num + Copy + PolarsHandler,
    {
        /// Convert this to a [`polars::prelude::DataFrame`].
        ///
        /// The frame has one record per stored triplet, with the columns
        /// "row", "col" and "values". When the array has more than one band,
        /// a leading "band" column carries the 1-based band index of each
        /// record.
        ///
        /// # Panics
        ///
        /// Panics if the data frame cannot be constructed from the columns.
        pub fn to_frame(&self) -> DataFrame {
            let triplets = &self.triplets;
            let n_values = triplets.data.len();
            let mut columns: Vec<Column> = Vec::new();

            // A band column is only emitted for multiband rasters.
            if self.lengths.len() > 1 {
                let mut band_ids: Vec<u64> = Vec::new();
                for (idx, &count) in self.lengths.iter().enumerate() {
                    // Band indices are 1-based; repeat each once per triplet.
                    band_ids.extend(std::iter::repeat_n((idx + 1) as u64, count));
                }
                columns.push(Column::new("band".into(), band_ids));
            }

            columns.push(Column::new("row".into(), triplets.rows.as_slice()));
            columns.push(Column::new("col".into(), triplets.cols.as_slice()));
            columns.push(N::from_named_vec("values", &triplets.data));

            DataFrame::new(n_values, columns).unwrap()
        }
    }
}