1use std::collections::BTreeMap;
2
3use crate::{
4 encoding::{
5 arrays::{DenseArray, SparseArray},
6 writers::{DenseArrayWriter, PixelWriter, SparseArrayWriter, ToSparseArray},
7 },
8 error::{RusterizeError, RusterizeResult},
9 prelude::{RasterDtype, RasterizeContext},
10 rasterization::{
11 burn_geometry::Burn,
12 burners::{AllTouched, AllTouchedCached, LineBurnStrategy, Standard},
13 },
14};
15use geo::Geometry;
16use ndarray::{ArrayView1, Axis};
17use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
18
19#[cfg(feature = "polars")]
20use polars::prelude::*;
21
22#[derive(Clone)]
24pub enum FieldSource<'a, N> {
25 Scalar(N),
27 Array(ArrayView1<'a, N>),
29 #[cfg(feature = "polars")]
30 Column(Column),
31}
32
33impl<'a, N, T> From<&'a T> for FieldSource<'a, N>
34where
35 T: AsRef<[N]> + ?Sized,
36{
37 fn from(v: &'a T) -> Self {
38 Self::Array(ArrayView1::from(v.as_ref()))
39 }
40}
41
// Selects the line-burning strategy for `process` from two runtime flags:
//   all_touched = false          -> `Standard`
//   all_touched = true,  dedup   -> `AllTouchedCached`
//   all_touched = true, !dedup   -> `AllTouched`
// The dtype parameter `N` is taken from the scope of the call site.
macro_rules! dispatch {
    ($all_touched:expr, $dedup:expr, $geoms:expr, $ctx:expr, $writer:expr, $idx:expr) => {{
        // Evaluate both flags exactly once, before choosing a branch.
        let touch_all: bool = $all_touched;
        let cached: bool = $dedup;

        if !touch_all {
            process::<N, _, Standard, _>($geoms, $ctx, $writer, $idx)
        } else if cached {
            process::<N, _, AllTouchedCached, _>($geoms, $ctx, $writer, $idx)
        } else {
            process::<N, _, AllTouched, _>($geoms, $ctx, $writer, $idx)
        }
    }};
}
51
52pub trait Rasterize {
55 fn rasterize<A: ArrayBuilder>(&self, ctx: RasterizeContext<A::Dtype>) -> RusterizeResult<A>;
56}
57
58impl<T: AsRef<[Geometry<f64>]> + ?Sized> Rasterize for T {
59 fn rasterize<A: ArrayBuilder>(&self, ctx: RasterizeContext<A::Dtype>) -> RusterizeResult<A> {
60 A::build(self.as_ref(), ctx)
61 }
62}
63
64pub trait ArrayBuilder: Sized {
66 type Dtype: RasterDtype;
67
68 fn build(geoms: &[Geometry<f64>], ctx: RasterizeContext<Self::Dtype>) -> RusterizeResult<Self>;
69}
70
71impl<N> ArrayBuilder for DenseArray<N>
72where
73 N: RasterDtype,
74{
75 type Dtype = N;
76
77 fn build(geoms: &[Geometry<f64>], ctx: RasterizeContext<Self::Dtype>) -> RusterizeResult<Self> {
78 assert_matching_len(geoms.len(), &ctx.field, ctx.by)?;
79
80 let dedup = ctx.requires_dedup();
81
82 match ctx.by {
83 Some(by) => {
84 let (groups, groups_idx) = group_keys(by);
85 let n_groups = groups.len();
86 let mut band_names = Vec::with_capacity(n_groups);
87 let mut raster = ctx.raster_info.build_raster(n_groups, ctx.background);
88
89 raster
90 .outer_iter_mut()
91 .into_par_iter()
92 .zip(groups.into_par_iter())
93 .zip(groups_idx.into_par_iter())
94 .map(|((band, name), idxs)| {
95 let mut writer = DenseArrayWriter::new(band, ctx.pixel_fn());
96
97 dispatch!(ctx.all_touched, dedup, geoms, &ctx, &mut writer, idxs.iter().copied());
98
99 name
100 })
101 .collect_into_vec(&mut band_names);
102
103 Ok(DenseArray::new(raster, band_names, ctx.raster_info))
104 }
105 None => {
106 let band_names = vec![String::from("band_1")];
107 let mut raster = ctx.raster_info.build_raster(1, ctx.background);
108 let mut writer = DenseArrayWriter::new(raster.index_axis_mut(Axis(0), 0), ctx.pixel_fn());
109
110 dispatch!(ctx.all_touched, dedup, geoms, &ctx, &mut writer, 0..geoms.len());
111
112 Ok(DenseArray::new(raster, band_names, ctx.raster_info))
113 }
114 }
115 }
116}
117
118impl<N> ArrayBuilder for SparseArray<N>
119where
120 N: RasterDtype,
121{
122 type Dtype = N;
123
124 fn build(geoms: &[Geometry<f64>], ctx: RasterizeContext<Self::Dtype>) -> RusterizeResult<Self> {
125 assert_matching_len(geoms.len(), &ctx.field, ctx.by)?;
126
127 let dedup = ctx.requires_dedup();
128
129 match ctx.by {
130 Some(by) => {
131 let (groups, groups_idx) = group_keys(by);
132 let mut writers = Vec::with_capacity(groups.len());
133
134 groups
135 .into_par_iter()
136 .zip(groups_idx.into_par_iter())
137 .map(|(name, idxs)| {
138 let mut writer = SparseArrayWriter::new(name);
139
140 dispatch!(ctx.all_touched, dedup, geoms, &ctx, &mut writer, idxs.iter().copied());
141
142 writer
143 })
144 .collect_into_vec(&mut writers);
145
146 Ok(writers.finish(ctx))
147 }
148 None => {
149 let mut writer = SparseArrayWriter::new(String::from("band_1"));
150
151 dispatch!(ctx.all_touched, dedup, geoms, &ctx, &mut writer, 0..geoms.len());
152
153 Ok(writer.finish(ctx))
154 }
155 }
156 }
157}
158
159#[cfg_attr(feature = "hotpath", hotpath::measure)]
162fn process<N, W, S, I>(geoms: &[Geometry<f64>], ctx: &RasterizeContext<N>, writer: &mut W, indices: I)
163where
164 N: RasterDtype,
165 W: PixelWriter<N>,
166 S: LineBurnStrategy,
167 I: Iterator<Item = usize>,
168{
169 match &ctx.field {
170 FieldSource::Scalar(s) => {
171 for i in indices {
172 geoms[i].burn::<S>(&ctx.raster_info, *s, writer, ctx.background);
173 }
174 }
175 FieldSource::Array(arr) => {
176 for i in indices {
177 geoms[i].burn::<S>(&ctx.raster_info, arr[i], writer, ctx.background);
178 }
179 }
180 #[cfg(feature = "polars")]
181 FieldSource::Column(col) => {
182 let ca = col.as_materialized_series().unpack::<N::ChunkedArrayType>().unwrap();
183 if let Ok(slice) = ca.cont_slice() {
184 for i in indices {
185 geoms[i].burn::<S>(&ctx.raster_info, slice[i], writer, ctx.background);
186 }
187 } else {
188 for i in indices {
189 if let Some(fv) = ca.get(i) {
190 geoms[i].burn::<S>(&ctx.raster_info, fv, writer, ctx.background);
191 }
192 }
193 }
194 }
195 }
196}
197
/// Groups positions of `by` by key.
///
/// Returns the distinct keys in ascending order together with, for each key,
/// the ascending list of positions at which it occurs. Both vectors have the
/// same length and are aligned index-for-index.
fn group_keys(by: &[String]) -> (Vec<String>, Vec<Vec<usize>>) {
    // A BTreeMap keeps the keys sorted, which fixes the output band order.
    let mut members: BTreeMap<&str, Vec<usize>> = BTreeMap::new();
    by.iter().enumerate().for_each(|(pos, label)| {
        members.entry(label.as_str()).or_default().push(pos);
    });

    let mut names = Vec::with_capacity(members.len());
    let mut positions = Vec::with_capacity(members.len());
    for (label, idxs) in members {
        names.push(label.to_owned());
        positions.push(idxs);
    }
    (names, positions)
}
206
207fn assert_matching_len<N>(n_geoms: usize, field: &FieldSource<N>, by: Option<&[String]>) -> RusterizeResult<()> {
209 let field_len = match field {
210 FieldSource::Array(arr) => Some(arr.len()),
211 #[cfg(feature = "polars")]
212 FieldSource::Column(col) => Some(col.len()),
213 FieldSource::Scalar(_) => None,
214 };
215
216 if let Some(field_len) = field_len
217 && field_len != n_geoms
218 {
219 return Err(RusterizeError::ValueError("Geometry and field lengths must match"));
220 }
221
222 if let Some(by) = by
223 && by.len() != n_geoms
224 {
225 return Err(RusterizeError::ValueError("Geometry and by lengths must match"));
226 }
227
228 Ok(())
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{geo::raster::RasterInfo, rasterization::pixel_functions::PixelFunction};
    use geo::{Geometry, LineString, Point, Polygon};
    use ndarray::Array1;

    /// 4x4 grid with resolution 1.0 spanning `[0, 4]` on both axes, no EPSG code.
    fn raster_4x4() -> RasterInfo {
        RasterInfo {
            ncols: 4,
            nrows: 4,
            xmin: 0.0,
            xmax: 4.0,
            ymin: 0.0,
            ymax: 4.0,
            xres: 1.0,
            yres: 1.0,
            epsg: None,
        }
    }

    /// A single polygon with a scalar field yields one 4x4 band containing the value.
    #[test]
    fn dense_burns_a_polygon() {
        let ring = LineString::from(vec![(0.5, 0.5), (3.5, 0.5), (3.5, 3.5), (0.5, 3.5), (0.5, 0.5)]);
        let geoms = vec![Geometry::Polygon(Polygon::new(ring, vec![]))];
        let ctx = RasterizeContext {
            raster_info: raster_4x4(),
            field: FieldSource::Scalar(1.0_f64),
            by: None,
            pixel_fn: PixelFunction::Last,
            background: 0.0,
            all_touched: false,
        };

        let dense: DenseArray<f64> = geoms.rasterize(ctx).unwrap();
        let (raster, _, _) = dense.into_parts();

        assert_eq!(raster.shape(), &[1, 4, 4]);
        let burned = raster.iter().any(|&v| v == 1.0);
        assert!(burned, "polygon should burn at least one cell");
    }

    /// With two groups, each band must contain exactly one of the two field values.
    #[test]
    fn multiband_burns_only_its_group() {
        let geoms = vec![
            Geometry::Point(Point::new(0.5, 0.5)),
            Geometry::Point(Point::new(3.5, 3.5)),
        ];
        let keys = [String::from("a"), String::from("b")];
        let values = Array1::from(vec![1.0_f64, 2.0]);
        let ctx = RasterizeContext {
            raster_info: raster_4x4(),
            field: FieldSource::Array(values.view()),
            by: Some(&keys[..]),
            pixel_fn: PixelFunction::Last,
            background: 0.0,
            all_touched: false,
        };

        let dense: DenseArray<f64> = geoms.rasterize(ctx).unwrap();
        let (raster, _, _) = dense.into_parts();
        assert_eq!(raster.shape(), &[2, 4, 4]);

        for band in raster.outer_iter() {
            let holds_first = band.iter().any(|&v| v == 1.0);
            let holds_second = band.iter().any(|&v| v == 2.0);
            // Exactly one of the two values may appear in a band.
            assert!(holds_first != holds_second, "a band burned geometries outside its group");
        }
    }

    /// `group_keys` pairs each distinct key with the positions where it occurs.
    #[test]
    fn group_keys_groups_and_names() {
        let keys = [String::from("b"), String::from("a"), String::from("b")];
        let (names, members) = group_keys(&keys);

        let mut pairs: Vec<(String, Vec<usize>)> = names.into_iter().zip(members).collect();
        pairs.sort();

        let expected: Vec<(String, Vec<usize>)> = vec![("a".to_string(), vec![1]), ("b".to_string(), vec![0, 2])];
        assert_eq!(pairs, expected);
    }
}