sedona_testing/
rasters.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17use arrow_array::StructArray;
18use datafusion_common::Result;
19use fastrand::Rng;
20use sedona_raster::array::RasterStructArray;
21use sedona_raster::builder::RasterBuilder;
22use sedona_raster::traits::{BandMetadata, RasterMetadata, RasterRef};
23use sedona_schema::raster::{BandDataType, StorageType};
24
25/// Generate a StructArray of rasters with sequentially increasing dimensions and pixel values
26/// These tiny rasters are to provide fast, easy and predictable test data for unit tests.
27pub fn generate_test_rasters(
28    count: usize,
29    null_raster_index: Option<usize>,
30) -> Result<StructArray> {
31    let mut builder = RasterBuilder::new(count);
32    for i in 0..count {
33        // If a null raster index is specified and that matches the current index,
34        // append a null raster
35        if matches!(null_raster_index, Some(index) if index == i) {
36            builder.append_null()?;
37            continue;
38        }
39
40        let raster_metadata = RasterMetadata {
41            width: i as u64 + 1,
42            height: i as u64 + 2,
43            upperleft_x: i as f64 + 1.0,
44            upperleft_y: i as f64 + 2.0,
45            scale_x: i as f64 * 0.1,
46            scale_y: i as f64 * 0.2,
47            skew_x: i as f64 * 0.3,
48            skew_y: i as f64 * 0.4,
49        };
50        builder.start_raster(&raster_metadata, None)?;
51        builder.start_band(BandMetadata {
52            datatype: BandDataType::UInt16,
53            nodata_value: Some(vec![0u8; 2]),
54            storage_type: StorageType::InDb,
55            outdb_url: None,
56            outdb_band_id: None,
57        })?;
58
59        let pixel_count = (i + 1) * (i + 2); // width * height
60        let mut band_data = Vec::with_capacity(pixel_count * 2); // 2 bytes per u16
61        for pixel_value in 0..pixel_count as u16 {
62            band_data.extend_from_slice(&pixel_value.to_le_bytes());
63        }
64
65        builder.band_data_writer().append_value(&band_data);
66        builder.finish_band()?;
67        builder.finish_raster()?;
68    }
69
70    Ok(builder.finish()?)
71}
72
73/// Generates a set of tiled rasters arranged in a grid
74/// - Each raster tile has specified dimensions and random pixel values
75/// - Each raster has 3 bands which can be interpreted as RGB values
76///   and the result can be visualized as a mosaic of tiles.
77/// - There are nodata values at the 4 corners of the overall mosaic.
78pub fn generate_tiled_rasters(
79    tile_size: (usize, usize),
80    number_of_tiles: (usize, usize),
81    data_type: BandDataType,
82    seed: Option<u64>,
83) -> Result<StructArray> {
84    let mut rng = match seed {
85        Some(s) => Rng::with_seed(s),
86        None => Rng::new(),
87    };
88    let (tile_width, tile_height) = tile_size;
89    let (x_tiles, y_tiles) = number_of_tiles;
90    let mut raster_builder = RasterBuilder::new(x_tiles * y_tiles);
91    let band_count = 3;
92
93    for tile_y in 0..y_tiles {
94        for tile_x in 0..x_tiles {
95            let origin_x = (tile_x * tile_width) as f64;
96            let origin_y = (tile_y * tile_height) as f64;
97
98            let raster_metadata = RasterMetadata {
99                width: tile_width as u64,
100                height: tile_height as u64,
101                upperleft_x: origin_x,
102                upperleft_y: origin_y,
103                scale_x: 1.0,
104                scale_y: 1.0,
105                skew_x: 0.0,
106                skew_y: 0.0,
107            };
108
109            raster_builder.start_raster(&raster_metadata, None)?;
110
111            for _ in 0..band_count {
112                // Set a nodata value appropriate for the data type
113                let nodata_value = get_nodata_value_for_type(&data_type);
114
115                let band_metadata = BandMetadata {
116                    nodata_value: nodata_value.clone(),
117                    storage_type: StorageType::InDb,
118                    datatype: data_type.clone(),
119                    outdb_url: None,
120                    outdb_band_id: None,
121                };
122
123                raster_builder.start_band(band_metadata)?;
124
125                let pixel_count = tile_width * tile_height;
126
127                // Determine which corner position (if any) should have nodata in this tile
128                let corner_position =
129                    get_corner_position(tile_x, tile_y, x_tiles, y_tiles, tile_width, tile_height);
130                let band_data = generate_random_band_data(
131                    pixel_count,
132                    &data_type,
133                    nodata_value.as_deref(),
134                    corner_position,
135                    &mut rng,
136                );
137
138                raster_builder.band_data_writer().append_value(&band_data);
139                raster_builder.finish_band()?;
140            }
141
142            raster_builder.finish_raster()?;
143        }
144    }
145
146    Ok(raster_builder.finish()?)
147}
148
149/// Determine if this tile contains a corner of the overall grid and return its position
150/// Returns Some(position) if this tile contains a corner, None otherwise
151fn get_corner_position(
152    tile_x: usize,
153    tile_y: usize,
154    x_tiles: usize,
155    y_tiles: usize,
156    tile_width: usize,
157    tile_height: usize,
158) -> Option<usize> {
159    // Top-left corner (tile 0,0, pixel 0)
160    if tile_x == 0 && tile_y == 0 {
161        return Some(0);
162    }
163    // Top-right corner (tile x_tiles-1, 0, pixel tile_width-1)
164    if tile_x == x_tiles - 1 && tile_y == 0 {
165        return Some(tile_width - 1);
166    }
167    // Bottom-left corner (tile 0, y_tiles-1, pixel (tile_height-1)*tile_width)
168    if tile_x == 0 && tile_y == y_tiles - 1 {
169        return Some((tile_height - 1) * tile_width);
170    }
171    // Bottom-right corner (tile x_tiles-1, y_tiles-1, pixel tile_height*tile_width-1)
172    if tile_x == x_tiles - 1 && tile_y == y_tiles - 1 {
173        return Some(tile_height * tile_width - 1);
174    }
175    None
176}
177
178fn generate_random_band_data(
179    pixel_count: usize,
180    data_type: &BandDataType,
181    nodata_bytes: Option<&[u8]>,
182    corner_position: Option<usize>,
183    rng: &mut Rng,
184) -> Vec<u8> {
185    match data_type {
186        BandDataType::UInt8 => {
187            let mut data: Vec<u8> = (0..pixel_count).map(|_| rng.u8(..)).collect();
188            // Set corner pixel to nodata value if this tile contains a corner
189            if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
190                if !nodata.is_empty() && pos < data.len() {
191                    data[pos] = nodata[0];
192                }
193            }
194            data
195        }
196        BandDataType::UInt16 => {
197            let mut data = Vec::with_capacity(pixel_count * 2);
198            for _ in 0..pixel_count {
199                data.extend_from_slice(&rng.u16(..).to_ne_bytes());
200            }
201            // Set corner pixel to nodata value if this tile contains a corner
202            if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
203                if nodata.len() >= 2 && pos * 2 + 2 <= data.len() {
204                    data[pos * 2..(pos * 2) + 2].copy_from_slice(&nodata[0..2]);
205                }
206            }
207            data
208        }
209        BandDataType::Int16 => {
210            let mut data = Vec::with_capacity(pixel_count * 2);
211            for _ in 0..pixel_count {
212                data.extend_from_slice(&rng.i16(..).to_ne_bytes());
213            }
214            // Set corner pixel to nodata value if this tile contains a corner
215            if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
216                if nodata.len() >= 2 && pos * 2 + 2 <= data.len() {
217                    data[pos * 2..(pos * 2) + 2].copy_from_slice(&nodata[0..2]);
218                }
219            }
220            data
221        }
222        BandDataType::UInt32 => {
223            let mut data = Vec::with_capacity(pixel_count * 4);
224            for _ in 0..pixel_count {
225                data.extend_from_slice(&rng.u32(..).to_ne_bytes());
226            }
227            // Set corner pixel to nodata value if this tile contains a corner
228            if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
229                if nodata.len() >= 4 && pos * 4 + 4 <= data.len() {
230                    data[pos * 4..(pos * 4) + 4].copy_from_slice(&nodata[0..4]);
231                }
232            }
233            data
234        }
235        BandDataType::Int32 => {
236            let mut data = Vec::with_capacity(pixel_count * 4);
237            for _ in 0..pixel_count {
238                data.extend_from_slice(&rng.i32(..).to_ne_bytes());
239            }
240            // Set corner pixel to nodata value if this tile contains a corner
241            if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
242                if nodata.len() >= 4 && pos * 4 + 4 <= data.len() {
243                    data[pos * 4..(pos * 4) + 4].copy_from_slice(&nodata[0..4]);
244                }
245            }
246            data
247        }
248        BandDataType::Float32 => {
249            let mut data = Vec::with_capacity(pixel_count * 4);
250            for _ in 0..pixel_count {
251                data.extend_from_slice(&rng.f32().to_ne_bytes());
252            }
253            // Set corner pixel to nodata value if this tile contains a corner
254            if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
255                if nodata.len() >= 4 && pos * 4 + 4 <= data.len() {
256                    data[pos * 4..(pos * 4) + 4].copy_from_slice(&nodata[0..4]);
257                }
258            }
259            data
260        }
261        BandDataType::Float64 => {
262            let mut data = Vec::with_capacity(pixel_count * 8);
263            for _ in 0..pixel_count {
264                data.extend_from_slice(&rng.f64().to_ne_bytes());
265            }
266            // Set corner pixel to nodata value if this tile contains a corner
267            if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
268                if nodata.len() >= 8 && pos * 8 + 8 <= data.len() {
269                    data[pos * 8..(pos * 8) + 8].copy_from_slice(&nodata[0..8]);
270                }
271            }
272            data
273        }
274    }
275}
276
277fn get_nodata_value_for_type(data_type: &BandDataType) -> Option<Vec<u8>> {
278    match data_type {
279        BandDataType::UInt8 => Some(vec![255u8]),
280        BandDataType::UInt16 => Some(u16::MAX.to_ne_bytes().to_vec()),
281        BandDataType::Int16 => Some(i16::MIN.to_ne_bytes().to_vec()),
282        BandDataType::UInt32 => Some(u32::MAX.to_ne_bytes().to_vec()),
283        BandDataType::Int32 => Some(i32::MIN.to_ne_bytes().to_vec()),
284        BandDataType::Float32 => Some(f32::NAN.to_ne_bytes().to_vec()),
285        BandDataType::Float64 => Some(f64::NAN.to_ne_bytes().to_vec()),
286    }
287}
288
289/// Compare two RasterStructArrays for equality
290pub fn assert_raster_arrays_equal(
291    raster_array1: &RasterStructArray,
292    raster_array2: &RasterStructArray,
293) {
294    assert_eq!(
295        raster_array1.len(),
296        raster_array2.len(),
297        "Raster array lengths do not match"
298    );
299
300    for i in 0..raster_array1.len() {
301        let raster1 = raster_array1.get(i).unwrap();
302        let raster2 = raster_array2.get(i).unwrap();
303        assert_raster_equal(&raster1, &raster2);
304    }
305}
306
307/// Compare two rasters for equality
308pub fn assert_raster_equal(raster1: &impl RasterRef, raster2: &impl RasterRef) {
309    // Compare metadata
310    let meta1 = raster1.metadata();
311    let meta2 = raster2.metadata();
312    assert_eq!(meta1.width(), meta2.width(), "Raster widths do not match");
313    assert_eq!(
314        meta1.height(),
315        meta2.height(),
316        "Raster heights do not match"
317    );
318    assert_eq!(
319        meta1.upper_left_x(),
320        meta2.upper_left_x(),
321        "Raster upper left x does not match"
322    );
323    assert_eq!(
324        meta1.upper_left_y(),
325        meta2.upper_left_y(),
326        "Raster upper left y does not match"
327    );
328    assert_eq!(
329        meta1.scale_x(),
330        meta2.scale_x(),
331        "Raster scale x does not match"
332    );
333    assert_eq!(
334        meta1.scale_y(),
335        meta2.scale_y(),
336        "Raster scale y does not match"
337    );
338    assert_eq!(
339        meta1.skew_x(),
340        meta2.skew_x(),
341        "Raster skew x does not match"
342    );
343    assert_eq!(
344        meta1.skew_y(),
345        meta2.skew_y(),
346        "Raster skew y does not match"
347    );
348
349    // Compare bands
350    let bands1 = raster1.bands();
351    let bands2 = raster2.bands();
352    assert_eq!(bands1.len(), bands2.len(), "Number of bands do not match");
353
354    for band_index in 0..bands1.len() {
355        let band1 = bands1.band(band_index + 1).unwrap();
356        let band2 = bands2.band(band_index + 1).unwrap();
357
358        let band_meta1 = band1.metadata();
359        let band_meta2 = band2.metadata();
360        assert_eq!(
361            band_meta1.data_type(),
362            band_meta2.data_type(),
363            "Band data types do not match"
364        );
365        assert_eq!(
366            band_meta1.nodata_value(),
367            band_meta2.nodata_value(),
368            "Band nodata values do not match"
369        );
370        assert_eq!(
371            band_meta1.storage_type(),
372            band_meta2.storage_type(),
373            "Band storage types do not match"
374        );
375        assert_eq!(
376            band_meta1.outdb_url(),
377            band_meta2.outdb_url(),
378            "Band outdb URLs do not match"
379        );
380        assert_eq!(
381            band_meta1.outdb_band_id(),
382            band_meta2.outdb_band_id(),
383            "Band outdb band IDs do not match"
384        );
385
386        assert_eq!(band1.data(), band2.data(), "Band data does not match");
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use sedona_raster::array::RasterStructArray;
    use sedona_raster::traits::RasterRef;

    #[test]
    fn test_generate_test_rasters() {
        let count = 5;
        let struct_array = generate_test_rasters(count, None).unwrap();
        let raster_array = RasterStructArray::new(&struct_array);
        assert_eq!(raster_array.len(), count);

        for i in 0..count {
            let raster = raster_array.get(i).unwrap();
            let metadata = raster.metadata();
            assert_eq!(metadata.width(), i as u64 + 1);
            assert_eq!(metadata.height(), i as u64 + 2);
            assert_eq!(metadata.upper_left_x(), i as f64 + 1.0);
            assert_eq!(metadata.upper_left_y(), i as f64 + 2.0);
            assert_eq!(metadata.scale_x(), (i as f64) * 0.1);
            assert_eq!(metadata.scale_y(), (i as f64) * 0.2);
            assert_eq!(metadata.skew_x(), (i as f64) * 0.3);
            assert_eq!(metadata.skew_y(), (i as f64) * 0.4);

            let bands = raster.bands();
            let band = bands.band(1).unwrap();
            let band_metadata = band.metadata();
            assert_eq!(band_metadata.data_type(), BandDataType::UInt16);
            assert_eq!(band_metadata.nodata_value(), Some(&[0u8, 0u8][..]));
            assert_eq!(band_metadata.storage_type(), StorageType::InDb);
            assert_eq!(band_metadata.outdb_url(), None);
            assert_eq!(band_metadata.outdb_band_id(), None);

            let band_data = band.data();
            let expected_pixel_count = (i + 1) * (i + 2); // width * height

            // Convert raw bytes back to u16 values for comparison
            let mut actual_pixel_values = Vec::new();
            for chunk in band_data.chunks_exact(2) {
                let value = u16::from_le_bytes([chunk[0], chunk[1]]);
                actual_pixel_values.push(value);
            }
            let expected_pixel_values: Vec<u16> = (0..expected_pixel_count as u16).collect();
            assert_eq!(actual_pixel_values, expected_pixel_values);
        }
    }

    #[test]
    fn test_generate_test_rasters_with_null() {
        // Index 1 is null; the remaining rasters keep the properties of their own index.
        let struct_array = generate_test_rasters(3, Some(1)).unwrap();
        let raster_array = RasterStructArray::new(&struct_array);
        assert_eq!(raster_array.len(), 3);
        assert_eq!(raster_array.get(0).unwrap().metadata().width(), 1);
        assert_eq!(raster_array.get(2).unwrap().metadata().width(), 3);
    }

    #[test]
    fn test_generate_tiled_rasters() {
        let tile_size = (64, 64);
        let number_of_tiles = (4, 4);
        let data_type = BandDataType::UInt8;
        let struct_array =
            generate_tiled_rasters(tile_size, number_of_tiles, data_type, Some(43)).unwrap();
        let raster_array = RasterStructArray::new(&struct_array);
        assert_eq!(raster_array.len(), 16); // 4x4 tiles
        for i in 0..16 {
            let raster = raster_array.get(i).unwrap();
            let metadata = raster.metadata();
            assert_eq!(metadata.width(), 64);
            assert_eq!(metadata.height(), 64);
            assert_eq!(metadata.upper_left_x(), ((i % 4) * 64) as f64);
            assert_eq!(metadata.upper_left_y(), ((i / 4) * 64) as f64);
            let bands = raster.bands();
            assert_eq!(bands.len(), 3);
            for band_index in 0..3 {
                let band = bands.band(band_index + 1).unwrap();
                let band_metadata = band.metadata();
                assert_eq!(band_metadata.data_type(), BandDataType::UInt8);
                assert_eq!(band_metadata.storage_type(), StorageType::InDb);
                assert_eq!(band_metadata.nodata_value(), Some(&[255u8][..]));
                let band_data = band.data();
                assert_eq!(band_data.len(), 64 * 64); // 4096 pixels
            }
        }

        // The 4 corners of the overall mosaic hold the UInt8 nodata value (255).
        // Each entry is (raster index, pixel offset within that 64x64 tile).
        let corners = [(0usize, 0usize), (3, 63), (12, 63 * 64), (15, 64 * 64 - 1)];
        for &(raster_index, pixel_offset) in &corners {
            let raster = raster_array.get(raster_index).unwrap();
            let bands = raster.bands();
            for band_index in 0..3 {
                let band = bands.band(band_index + 1).unwrap();
                assert_eq!(
                    band.data()[pixel_offset],
                    255u8,
                    "corner pixel {pixel_offset} of raster {raster_index} is not nodata"
                );
            }
        }
    }

    #[test]
    fn test_raster_arrays_equal() {
        let raster_array1 = generate_test_rasters(3, None).unwrap();
        let raster_struct_array1 = RasterStructArray::new(&raster_array1);
        // Test that identical arrays are equal
        assert_raster_arrays_equal(&raster_struct_array1, &raster_struct_array1);
    }

    #[test]
    #[should_panic = "Raster array lengths do not match"]
    fn test_raster_arrays_not_equal() {
        let raster_array1 = generate_test_rasters(3, None).unwrap();
        let raster_struct_array1 = RasterStructArray::new(&raster_array1);

        // Test that arrays with different lengths are not equal
        let raster_array2 = generate_test_rasters(4, None).unwrap();
        let raster_struct_array2 = RasterStructArray::new(&raster_array2);
        assert_raster_arrays_equal(&raster_struct_array1, &raster_struct_array2);
    }

    #[test]
    fn test_raster_equal() {
        let raster_array1 =
            generate_tiled_rasters((256, 256), (1, 1), BandDataType::UInt8, Some(43)).unwrap();
        let raster1 = RasterStructArray::new(&raster_array1).get(0).unwrap();

        // Assert that the rasters are equal to themselves
        assert_raster_equal(&raster1, &raster1);
    }

    #[test]
    #[should_panic = "Band data does not match"]
    fn test_raster_different_band_data() {
        let raster_array1 =
            generate_tiled_rasters((128, 128), (1, 1), BandDataType::UInt8, Some(43)).unwrap();
        let raster_array2 =
            generate_tiled_rasters((128, 128), (1, 1), BandDataType::UInt8, Some(47)).unwrap();

        let raster1 = RasterStructArray::new(&raster_array1).get(0).unwrap();
        let raster2 = RasterStructArray::new(&raster_array2).get(0).unwrap();
        assert_raster_equal(&raster1, &raster2);
    }

    #[test]
    #[should_panic = "Raster upper left x does not match"]
    fn test_raster_different_metadata() {
        let raster_array =
            generate_tiled_rasters((128, 128), (2, 1), BandDataType::UInt8, Some(43)).unwrap();
        let raster1 = RasterStructArray::new(&raster_array).get(0).unwrap();
        let raster2 = RasterStructArray::new(&raster_array).get(1).unwrap();
        assert_raster_equal(&raster1, &raster2);
    }
}