Skip to main content

rusterize/
rasterize.rs

1use std::collections::BTreeMap;
2
3use crate::{
4    encoding::{
5        arrays::{DenseArray, SparseArray},
6        writers::{DenseArrayWriter, PixelWriter, SparseArrayWriter, ToSparseArray},
7    },
8    error::{RusterizeError, RusterizeResult},
9    prelude::{RasterDtype, RasterizeContext},
10    rasterization::{
11        burn_geometry::Burn,
12        burners::{AllTouched, AllTouchedCached, LineBurnStrategy, Standard},
13    },
14};
15use geo::Geometry;
16use ndarray::{ArrayView1, Axis};
17use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
18
19#[cfg(feature = "polars")]
20use polars::prelude::*;
21
/// Source of values to burn onto a [`DenseArray`] or [`SparseArray`].
///
/// `'a` is the lifetime of the data borrowed by [`FieldSource::Array`]; `N` is the type of
/// the burned values.
#[derive(Clone)]
pub enum FieldSource<'a, N> {
    /// A single constant value to burn.
    Scalar(N),
    /// An array of values each associated to a unique geometry.
    Array(ArrayView1<'a, N>),
    /// A polars column of values, looked up by geometry index like `Array`.
    /// Only compiled in when the `polars` feature is enabled.
    #[cfg(feature = "polars")]
    Column(Column),
}
32
33impl<'a, N, T> From<&'a T> for FieldSource<'a, N>
34where
35    T: AsRef<[N]> + ?Sized,
36{
37    fn from(v: &'a T) -> Self {
38        Self::Array(ArrayView1::from(v.as_ref()))
39    }
40}
41
// Calls `process` with the line-burning strategy chosen by two runtime flags:
//   !all_touched          -> `Standard` (the dedup flag is ignored)
//   all_touched && !dedup -> `AllTouched`
//   all_touched && dedup  -> `AllTouchedCached`
// The expansion names the type parameter `N`, so it must be used where `N` is in scope.
macro_rules! dispatch {
    ($touched:expr, $cached:expr, $geoms:expr, $ctx:expr, $sink:expr, $indices:expr) => {
        match ($touched, $cached) {
            (false, _) => process::<N, _, Standard, _>($geoms, $ctx, $sink, $indices),
            (true, false) => process::<N, _, AllTouched, _>($geoms, $ctx, $sink, $indices),
            (true, true) => process::<N, _, AllTouchedCached, _>($geoms, $ctx, $sink, $indices),
        }
    };
}
51
/// Rasterization trait. Attaches to anything that can be viewed as a [`geo::Geometry`] slice
/// and produces a [`DenseArray`] or a [`SparseArray`].
pub trait Rasterize {
    /// Rasterize `self` into an array of type `A`, configured by `ctx`.
    ///
    /// The caller selects the output representation through `A`; the context must be
    /// parameterized with the matching `A::Dtype`.
    ///
    /// # Errors
    ///
    /// Returns `Err` when the array cannot be built from the geometries and `ctx`.
    fn rasterize<A: ArrayBuilder>(&self, ctx: RasterizeContext<A::Dtype>) -> RusterizeResult<A>;
}
57
58impl<T: AsRef<[Geometry<f64>]> + ?Sized> Rasterize for T {
59    fn rasterize<A: ArrayBuilder>(&self, ctx: RasterizeContext<A::Dtype>) -> RusterizeResult<A> {
60        A::build(self.as_ref(), ctx)
61    }
62}
63
/// [`DenseArray`] or [`SparseArray`] creation trait.
pub trait ArrayBuilder: Sized {
    /// Value type of the produced array; also the type parameter of the
    /// [`RasterizeContext`] accepted by [`ArrayBuilder::build`].
    type Dtype: RasterDtype;

    /// Build the array by burning `geoms` as configured by `ctx`.
    ///
    /// # Errors
    ///
    /// Returns `Err` when the array cannot be built. The implementations in this file fail
    /// when the field or `by` lengths differ from the number of geometries.
    fn build(geoms: &[Geometry<f64>], ctx: RasterizeContext<Self::Dtype>) -> RusterizeResult<Self>;
}
70
71impl<N> ArrayBuilder for DenseArray<N>
72where
73    N: RasterDtype,
74{
75    type Dtype = N;
76
77    fn build(geoms: &[Geometry<f64>], ctx: RasterizeContext<Self::Dtype>) -> RusterizeResult<Self> {
78        assert_matching_len(geoms.len(), &ctx.field, ctx.by)?;
79
80        let dedup = ctx.requires_dedup();
81
82        match ctx.by {
83            Some(by) => {
84                let (groups, groups_idx) = group_keys(by);
85                let n_groups = groups.len();
86                let mut band_names = Vec::with_capacity(n_groups);
87                let mut raster = ctx.raster_info.build_raster(n_groups, ctx.background);
88
89                raster
90                    .outer_iter_mut()
91                    .into_par_iter()
92                    .zip(groups.into_par_iter())
93                    .zip(groups_idx.into_par_iter())
94                    .map(|((band, name), idxs)| {
95                        let mut writer = DenseArrayWriter::new(band, ctx.pixel_fn());
96
97                        dispatch!(ctx.all_touched, dedup, geoms, &ctx, &mut writer, idxs.iter().copied());
98
99                        name
100                    })
101                    .collect_into_vec(&mut band_names);
102
103                Ok(DenseArray::new(raster, band_names, ctx.raster_info))
104            }
105            None => {
106                let band_names = vec![String::from("band_1")];
107                let mut raster = ctx.raster_info.build_raster(1, ctx.background);
108                let mut writer = DenseArrayWriter::new(raster.index_axis_mut(Axis(0), 0), ctx.pixel_fn());
109
110                dispatch!(ctx.all_touched, dedup, geoms, &ctx, &mut writer, 0..geoms.len());
111
112                Ok(DenseArray::new(raster, band_names, ctx.raster_info))
113            }
114        }
115    }
116}
117
118impl<N> ArrayBuilder for SparseArray<N>
119where
120    N: RasterDtype,
121{
122    type Dtype = N;
123
124    fn build(geoms: &[Geometry<f64>], ctx: RasterizeContext<Self::Dtype>) -> RusterizeResult<Self> {
125        assert_matching_len(geoms.len(), &ctx.field, ctx.by)?;
126
127        let dedup = ctx.requires_dedup();
128
129        match ctx.by {
130            Some(by) => {
131                let (groups, groups_idx) = group_keys(by);
132                let mut writers = Vec::with_capacity(groups.len());
133
134                groups
135                    .into_par_iter()
136                    .zip(groups_idx.into_par_iter())
137                    .map(|(name, idxs)| {
138                        let mut writer = SparseArrayWriter::new(name);
139
140                        dispatch!(ctx.all_touched, dedup, geoms, &ctx, &mut writer, idxs.iter().copied());
141
142                        writer
143                    })
144                    .collect_into_vec(&mut writers);
145
146                Ok(writers.finish(ctx))
147            }
148            None => {
149                let mut writer = SparseArrayWriter::new(String::from("band_1"));
150
151                dispatch!(ctx.all_touched, dedup, geoms, &ctx, &mut writer, 0..geoms.len());
152
153                Ok(writer.finish(ctx))
154            }
155        }
156    }
157}
158
159/// Burn the geometries at `indices` onto `writer`.
160/// `indices` is `0..len` for a single band, or the group's geometry indexes for multiband.
161#[cfg_attr(feature = "hotpath", hotpath::measure)]
162fn process<N, W, S, I>(geoms: &[Geometry<f64>], ctx: &RasterizeContext<N>, writer: &mut W, indices: I)
163where
164    N: RasterDtype,
165    W: PixelWriter<N>,
166    S: LineBurnStrategy,
167    I: Iterator<Item = usize>,
168{
169    match &ctx.field {
170        FieldSource::Scalar(s) => {
171            for i in indices {
172                geoms[i].burn::<S>(&ctx.raster_info, *s, writer, ctx.background);
173            }
174        }
175        FieldSource::Array(arr) => {
176            for i in indices {
177                geoms[i].burn::<S>(&ctx.raster_info, arr[i], writer, ctx.background);
178            }
179        }
180        #[cfg(feature = "polars")]
181        FieldSource::Column(col) => {
182            let ca = col.as_materialized_series().unpack::<N::ChunkedArrayType>().unwrap();
183            if let Ok(slice) = ca.cont_slice() {
184                for i in indices {
185                    geoms[i].burn::<S>(&ctx.raster_info, slice[i], writer, ctx.background);
186                }
187            } else {
188                for i in indices {
189                    if let Some(fv) = ca.get(i) {
190                        geoms[i].burn::<S>(&ctx.raster_info, fv, writer, ctx.background);
191                    }
192                }
193            }
194        }
195    }
196}
197
/// Partition the positions of `by` by key.
///
/// Returns two parallel vectors ordered by key: the distinct keys (used as band names) and,
/// for each key, the positions in `by` where it occurs, in ascending order.
fn group_keys(by: &[String]) -> (Vec<String>, Vec<Vec<usize>>) {
    // An ordered map keeps the keys sorted without a separate sort step.
    let mut positions_by_key: BTreeMap<&str, Vec<usize>> = BTreeMap::new();
    for (position, label) in by.iter().enumerate() {
        positions_by_key.entry(label.as_str()).or_default().push(position);
    }

    let mut names = Vec::with_capacity(positions_by_key.len());
    let mut members = Vec::with_capacity(positions_by_key.len());
    for (label, positions) in positions_by_key {
        names.push(label.to_owned());
        members.push(positions);
    }
    (names, members)
}
206
207/// Validate length of geometry, field, and by. Must match.
208fn assert_matching_len<N>(n_geoms: usize, field: &FieldSource<N>, by: Option<&[String]>) -> RusterizeResult<()> {
209    let field_len = match field {
210        FieldSource::Array(arr) => Some(arr.len()),
211        #[cfg(feature = "polars")]
212        FieldSource::Column(col) => Some(col.len()),
213        FieldSource::Scalar(_) => None,
214    };
215
216    if let Some(field_len) = field_len
217        && field_len != n_geoms
218    {
219        return Err(RusterizeError::ValueError("Geometry and field lengths must match"));
220    }
221
222    if let Some(by) = by
223        && by.len() != n_geoms
224    {
225        return Err(RusterizeError::ValueError("Geometry and by lengths must match"));
226    }
227
228    Ok(())
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{geo::raster::RasterInfo, rasterization::pixel_functions::PixelFunction};
    use geo::{Geometry, LineString, Polygon};

    /// A 4x4 grid with resolution 1 spanning `[0, 4]` on both axes, without an EPSG code.
    fn raster_4x4() -> RasterInfo {
        RasterInfo {
            ncols: 4,
            nrows: 4,
            xres: 1.0,
            yres: 1.0,
            xmin: 0.0,
            xmax: 4.0,
            ymin: 0.0,
            ymax: 4.0,
            epsg: None,
        }
    }

    /// A square polygon inside the grid must produce a single 4x4 band with at least one
    /// burned cell.
    #[test]
    fn dense_burns_a_polygon() {
        let ring = LineString::from(vec![(0.5, 0.5), (3.5, 0.5), (3.5, 3.5), (0.5, 3.5), (0.5, 0.5)]);
        let geoms = vec![Geometry::Polygon(Polygon::new(ring, vec![]))];
        let ctx = RasterizeContext {
            raster_info: raster_4x4(),
            field: FieldSource::Scalar(1.0_f64),
            by: None,
            pixel_fn: PixelFunction::Last,
            background: 0.0,
            all_touched: false,
        };

        let dense: DenseArray<f64> = geoms.rasterize(ctx).unwrap();
        let (raster, _, _) = dense.into_parts();

        assert_eq!(raster.shape(), &[1, 4, 4]);
        let any_burned = raster.iter().any(|&cell| cell == 1.0);
        assert!(any_burned, "polygon should burn at least one cell");
    }

    /// With two groups of one point each, every band must hold exactly one of the two
    /// field values.
    #[test]
    fn multiband_burns_only_its_group() {
        use geo::Point;
        use ndarray::Array1;

        let by = [String::from("a"), String::from("b")];
        let vals = Array1::from(vec![1.0_f64, 2.0]);
        let geoms = vec![
            Geometry::Point(Point::new(0.5, 0.5)),
            Geometry::Point(Point::new(3.5, 3.5)),
        ];
        let ctx = RasterizeContext {
            raster_info: raster_4x4(),
            field: FieldSource::Array(vals.view()),
            by: Some(&by[..]),
            pixel_fn: PixelFunction::Last,
            background: 0.0,
            all_touched: false,
        };

        let dense: DenseArray<f64> = geoms.rasterize(ctx).unwrap();
        let (raster, _, _) = dense.into_parts();
        assert_eq!(raster.shape(), &[2, 4, 4]);

        for band in raster.outer_iter() {
            let holds = |target: f64| band.iter().any(|&cell| cell == target);
            assert!(holds(1.0) != holds(2.0), "a band burned geometries outside its group");
        }
    }

    /// Keys map to the positions where they occur; the pairs are sorted before comparing.
    #[test]
    fn group_keys_groups_and_names() {
        let by = ["b", "a", "b"].map(String::from);
        let (names, idx) = group_keys(&by);

        let mut pairs: Vec<(String, Vec<usize>)> = names.into_iter().zip(idx).collect();
        pairs.sort();

        let expected: Vec<(String, Vec<usize>)> = vec![(String::from("a"), vec![1]), (String::from("b"), vec![0, 2])];
        assert_eq!(pairs, expected);
    }
}