1use arrow_array::StructArray;
18use datafusion_common::Result;
19use fastrand::Rng;
20use sedona_raster::array::RasterStructArray;
21use sedona_raster::builder::RasterBuilder;
22use sedona_raster::traits::{BandMetadata, RasterMetadata, RasterRef};
23use sedona_schema::raster::{BandDataType, StorageType};
24
25pub fn generate_test_rasters(
28 count: usize,
29 null_raster_index: Option<usize>,
30) -> Result<StructArray> {
31 let mut builder = RasterBuilder::new(count);
32 for i in 0..count {
33 if matches!(null_raster_index, Some(index) if index == i) {
36 builder.append_null()?;
37 continue;
38 }
39
40 let raster_metadata = RasterMetadata {
41 width: i as u64 + 1,
42 height: i as u64 + 2,
43 upperleft_x: i as f64 + 1.0,
44 upperleft_y: i as f64 + 2.0,
45 scale_x: i as f64 * 0.1,
46 scale_y: i as f64 * 0.2,
47 skew_x: i as f64 * 0.3,
48 skew_y: i as f64 * 0.4,
49 };
50 builder.start_raster(&raster_metadata, None)?;
51 builder.start_band(BandMetadata {
52 datatype: BandDataType::UInt16,
53 nodata_value: Some(vec![0u8; 2]),
54 storage_type: StorageType::InDb,
55 outdb_url: None,
56 outdb_band_id: None,
57 })?;
58
59 let pixel_count = (i + 1) * (i + 2); let mut band_data = Vec::with_capacity(pixel_count * 2); for pixel_value in 0..pixel_count as u16 {
62 band_data.extend_from_slice(&pixel_value.to_le_bytes());
63 }
64
65 builder.band_data_writer().append_value(&band_data);
66 builder.finish_band()?;
67 builder.finish_raster()?;
68 }
69
70 Ok(builder.finish()?)
71}
72
73pub fn generate_tiled_rasters(
79 tile_size: (usize, usize),
80 number_of_tiles: (usize, usize),
81 data_type: BandDataType,
82 seed: Option<u64>,
83) -> Result<StructArray> {
84 let mut rng = match seed {
85 Some(s) => Rng::with_seed(s),
86 None => Rng::new(),
87 };
88 let (tile_width, tile_height) = tile_size;
89 let (x_tiles, y_tiles) = number_of_tiles;
90 let mut raster_builder = RasterBuilder::new(x_tiles * y_tiles);
91 let band_count = 3;
92
93 for tile_y in 0..y_tiles {
94 for tile_x in 0..x_tiles {
95 let origin_x = (tile_x * tile_width) as f64;
96 let origin_y = (tile_y * tile_height) as f64;
97
98 let raster_metadata = RasterMetadata {
99 width: tile_width as u64,
100 height: tile_height as u64,
101 upperleft_x: origin_x,
102 upperleft_y: origin_y,
103 scale_x: 1.0,
104 scale_y: 1.0,
105 skew_x: 0.0,
106 skew_y: 0.0,
107 };
108
109 raster_builder.start_raster(&raster_metadata, None)?;
110
111 for _ in 0..band_count {
112 let nodata_value = get_nodata_value_for_type(&data_type);
114
115 let band_metadata = BandMetadata {
116 nodata_value: nodata_value.clone(),
117 storage_type: StorageType::InDb,
118 datatype: data_type.clone(),
119 outdb_url: None,
120 outdb_band_id: None,
121 };
122
123 raster_builder.start_band(band_metadata)?;
124
125 let pixel_count = tile_width * tile_height;
126
127 let corner_position =
129 get_corner_position(tile_x, tile_y, x_tiles, y_tiles, tile_width, tile_height);
130 let band_data = generate_random_band_data(
131 pixel_count,
132 &data_type,
133 nodata_value.as_deref(),
134 corner_position,
135 &mut rng,
136 );
137
138 raster_builder.band_data_writer().append_value(&band_data);
139 raster_builder.finish_band()?;
140 }
141
142 raster_builder.finish_raster()?;
143 }
144 }
145
146 Ok(raster_builder.finish()?)
147}
148
/// Returns the row-major pixel offset, within tile `(tile_x, tile_y)`, of a corner of
/// the whole `x_tiles` × `y_tiles` grid. Returns `None` if the tile holds no grid corner.
///
/// A tile can contain several grid corners, for example when the grid is a single tile
/// wide. In that case the first match wins, checked in this order:
/// 1. first row / first column
/// 2. first row / last column
/// 3. last row / first column
/// 4. last row / last column
///
/// Empty tiles (zero width or height) have no pixel to mark, so they yield `None`.
fn get_corner_position(
    tile_x: usize,
    tile_y: usize,
    x_tiles: usize,
    y_tiles: usize,
    tile_width: usize,
    tile_height: usize,
) -> Option<usize> {
    if tile_width == 0 || tile_height == 0 {
        return None;
    }

    // `idx + 1 == count` instead of `idx == count - 1` so a zero count cannot underflow.
    let first_col = tile_x == 0;
    let last_col = tile_x + 1 == x_tiles;
    let first_row = tile_y == 0;
    let last_row = tile_y + 1 == y_tiles;

    if first_row && first_col {
        Some(0)
    } else if first_row && last_col {
        Some(tile_width - 1)
    } else if last_row && first_col {
        Some((tile_height - 1) * tile_width)
    } else if last_row && last_col {
        Some(tile_height * tile_width - 1)
    } else {
        None
    }
}
177
178fn generate_random_band_data(
179 pixel_count: usize,
180 data_type: &BandDataType,
181 nodata_bytes: Option<&[u8]>,
182 corner_position: Option<usize>,
183 rng: &mut Rng,
184) -> Vec<u8> {
185 match data_type {
186 BandDataType::UInt8 => {
187 let mut data: Vec<u8> = (0..pixel_count).map(|_| rng.u8(..)).collect();
188 if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
190 if !nodata.is_empty() && pos < data.len() {
191 data[pos] = nodata[0];
192 }
193 }
194 data
195 }
196 BandDataType::UInt16 => {
197 let mut data = Vec::with_capacity(pixel_count * 2);
198 for _ in 0..pixel_count {
199 data.extend_from_slice(&rng.u16(..).to_ne_bytes());
200 }
201 if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
203 if nodata.len() >= 2 && pos * 2 + 2 <= data.len() {
204 data[pos * 2..(pos * 2) + 2].copy_from_slice(&nodata[0..2]);
205 }
206 }
207 data
208 }
209 BandDataType::Int16 => {
210 let mut data = Vec::with_capacity(pixel_count * 2);
211 for _ in 0..pixel_count {
212 data.extend_from_slice(&rng.i16(..).to_ne_bytes());
213 }
214 if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
216 if nodata.len() >= 2 && pos * 2 + 2 <= data.len() {
217 data[pos * 2..(pos * 2) + 2].copy_from_slice(&nodata[0..2]);
218 }
219 }
220 data
221 }
222 BandDataType::UInt32 => {
223 let mut data = Vec::with_capacity(pixel_count * 4);
224 for _ in 0..pixel_count {
225 data.extend_from_slice(&rng.u32(..).to_ne_bytes());
226 }
227 if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
229 if nodata.len() >= 4 && pos * 4 + 4 <= data.len() {
230 data[pos * 4..(pos * 4) + 4].copy_from_slice(&nodata[0..4]);
231 }
232 }
233 data
234 }
235 BandDataType::Int32 => {
236 let mut data = Vec::with_capacity(pixel_count * 4);
237 for _ in 0..pixel_count {
238 data.extend_from_slice(&rng.i32(..).to_ne_bytes());
239 }
240 if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
242 if nodata.len() >= 4 && pos * 4 + 4 <= data.len() {
243 data[pos * 4..(pos * 4) + 4].copy_from_slice(&nodata[0..4]);
244 }
245 }
246 data
247 }
248 BandDataType::Float32 => {
249 let mut data = Vec::with_capacity(pixel_count * 4);
250 for _ in 0..pixel_count {
251 data.extend_from_slice(&rng.f32().to_ne_bytes());
252 }
253 if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
255 if nodata.len() >= 4 && pos * 4 + 4 <= data.len() {
256 data[pos * 4..(pos * 4) + 4].copy_from_slice(&nodata[0..4]);
257 }
258 }
259 data
260 }
261 BandDataType::Float64 => {
262 let mut data = Vec::with_capacity(pixel_count * 8);
263 for _ in 0..pixel_count {
264 data.extend_from_slice(&rng.f64().to_ne_bytes());
265 }
266 if let (Some(nodata), Some(pos)) = (nodata_bytes, corner_position) {
268 if nodata.len() >= 8 && pos * 8 + 8 <= data.len() {
269 data[pos * 8..(pos * 8) + 8].copy_from_slice(&nodata[0..8]);
270 }
271 }
272 data
273 }
274 }
275}
276
277fn get_nodata_value_for_type(data_type: &BandDataType) -> Option<Vec<u8>> {
278 match data_type {
279 BandDataType::UInt8 => Some(vec![255u8]),
280 BandDataType::UInt16 => Some(u16::MAX.to_ne_bytes().to_vec()),
281 BandDataType::Int16 => Some(i16::MIN.to_ne_bytes().to_vec()),
282 BandDataType::UInt32 => Some(u32::MAX.to_ne_bytes().to_vec()),
283 BandDataType::Int32 => Some(i32::MIN.to_ne_bytes().to_vec()),
284 BandDataType::Float32 => Some(f32::NAN.to_ne_bytes().to_vec()),
285 BandDataType::Float64 => Some(f64::NAN.to_ne_bytes().to_vec()),
286 }
287}
288
289pub fn assert_raster_arrays_equal(
291 raster_array1: &RasterStructArray,
292 raster_array2: &RasterStructArray,
293) {
294 assert_eq!(
295 raster_array1.len(),
296 raster_array2.len(),
297 "Raster array lengths do not match"
298 );
299
300 for i in 0..raster_array1.len() {
301 let raster1 = raster_array1.get(i).unwrap();
302 let raster2 = raster_array2.get(i).unwrap();
303 assert_raster_equal(&raster1, &raster2);
304 }
305}
306
307pub fn assert_raster_equal(raster1: &impl RasterRef, raster2: &impl RasterRef) {
309 let meta1 = raster1.metadata();
311 let meta2 = raster2.metadata();
312 assert_eq!(meta1.width(), meta2.width(), "Raster widths do not match");
313 assert_eq!(
314 meta1.height(),
315 meta2.height(),
316 "Raster heights do not match"
317 );
318 assert_eq!(
319 meta1.upper_left_x(),
320 meta2.upper_left_x(),
321 "Raster upper left x does not match"
322 );
323 assert_eq!(
324 meta1.upper_left_y(),
325 meta2.upper_left_y(),
326 "Raster upper left y does not match"
327 );
328 assert_eq!(
329 meta1.scale_x(),
330 meta2.scale_x(),
331 "Raster scale x does not match"
332 );
333 assert_eq!(
334 meta1.scale_y(),
335 meta2.scale_y(),
336 "Raster scale y does not match"
337 );
338 assert_eq!(
339 meta1.skew_x(),
340 meta2.skew_x(),
341 "Raster skew x does not match"
342 );
343 assert_eq!(
344 meta1.skew_y(),
345 meta2.skew_y(),
346 "Raster skew y does not match"
347 );
348
349 let bands1 = raster1.bands();
351 let bands2 = raster2.bands();
352 assert_eq!(bands1.len(), bands2.len(), "Number of bands do not match");
353
354 for band_index in 0..bands1.len() {
355 let band1 = bands1.band(band_index + 1).unwrap();
356 let band2 = bands2.band(band_index + 1).unwrap();
357
358 let band_meta1 = band1.metadata();
359 let band_meta2 = band2.metadata();
360 assert_eq!(
361 band_meta1.data_type(),
362 band_meta2.data_type(),
363 "Band data types do not match"
364 );
365 assert_eq!(
366 band_meta1.nodata_value(),
367 band_meta2.nodata_value(),
368 "Band nodata values do not match"
369 );
370 assert_eq!(
371 band_meta1.storage_type(),
372 band_meta2.storage_type(),
373 "Band storage types do not match"
374 );
375 assert_eq!(
376 band_meta1.outdb_url(),
377 band_meta2.outdb_url(),
378 "Band outdb URLs do not match"
379 );
380 assert_eq!(
381 band_meta1.outdb_band_id(),
382 band_meta2.outdb_band_id(),
383 "Band outdb band IDs do not match"
384 );
385
386 assert_eq!(band1.data(), band2.data(), "Band data does not match");
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use sedona_raster::array::RasterStructArray;
    use sedona_raster::traits::RasterRef;

    #[test]
    fn test_generate_test_rasters() {
        let count = 5;
        let struct_array = generate_test_rasters(count, None).unwrap();
        let raster_array = RasterStructArray::new(&struct_array);
        assert_eq!(raster_array.len(), count);

        for i in 0..count {
            let raster = raster_array.get(i).unwrap();
            let metadata = raster.metadata();
            assert_eq!(metadata.width(), i as u64 + 1);
            assert_eq!(metadata.height(), i as u64 + 2);
            assert_eq!(metadata.upper_left_x(), i as f64 + 1.0);
            assert_eq!(metadata.upper_left_y(), i as f64 + 2.0);
            assert_eq!(metadata.scale_x(), (i as f64) * 0.1);
            assert_eq!(metadata.scale_y(), (i as f64) * 0.2);
            assert_eq!(metadata.skew_x(), (i as f64) * 0.3);
            assert_eq!(metadata.skew_y(), (i as f64) * 0.4);

            let bands = raster.bands();
            let band = bands.band(1).unwrap();
            let band_metadata = band.metadata();
            assert_eq!(band_metadata.data_type(), BandDataType::UInt16);
            assert_eq!(band_metadata.nodata_value(), Some(&[0u8, 0u8][..]));
            assert_eq!(band_metadata.storage_type(), StorageType::InDb);
            assert_eq!(band_metadata.outdb_url(), None);
            assert_eq!(band_metadata.outdb_band_id(), None);

            let band_data = band.data();
            let expected_pixel_count = (i + 1) * (i + 2);
            let actual_pixel_values: Vec<u16> = band_data
                .chunks_exact(2)
                .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
                .collect();
            let expected_pixel_values: Vec<u16> = (0..expected_pixel_count as u16).collect();
            assert_eq!(actual_pixel_values, expected_pixel_values);
        }
    }

    #[test]
    fn test_generate_tiled_rasters() {
        let (tile_width, tile_height) = (64, 64);
        let (x_tiles, y_tiles) = (4, 4);
        let data_type = BandDataType::UInt8;
        let struct_array = generate_tiled_rasters(
            (tile_width, tile_height),
            (x_tiles, y_tiles),
            data_type,
            Some(43),
        )
        .unwrap();
        let raster_array = RasterStructArray::new(&struct_array);
        assert_eq!(raster_array.len(), x_tiles * y_tiles);

        for i in 0..x_tiles * y_tiles {
            // Tiles are emitted row by row.
            let (tile_x, tile_y) = (i % x_tiles, i / x_tiles);
            let raster = raster_array.get(i).unwrap();
            let metadata = raster.metadata();
            assert_eq!(metadata.width(), tile_width as u64);
            assert_eq!(metadata.height(), tile_height as u64);
            assert_eq!(metadata.upper_left_x(), (tile_x * tile_width) as f64);
            assert_eq!(metadata.upper_left_y(), (tile_y * tile_height) as f64);

            // Tiles that hold a corner of the whole grid carry nodata at that corner pixel.
            let expected_nodata_pixel = match (tile_x, tile_y) {
                (0, 0) => Some(0),
                (x, 0) if x == x_tiles - 1 => Some(tile_width - 1),
                (0, y) if y == y_tiles - 1 => Some((tile_height - 1) * tile_width),
                (x, y) if x == x_tiles - 1 && y == y_tiles - 1 => {
                    Some(tile_height * tile_width - 1)
                }
                _ => None,
            };

            let bands = raster.bands();
            assert_eq!(bands.len(), 3);
            for band_number in 1..=3 {
                let band = bands.band(band_number).unwrap();
                let band_metadata = band.metadata();
                assert_eq!(band_metadata.data_type(), BandDataType::UInt8);
                assert_eq!(band_metadata.nodata_value(), Some(&[u8::MAX][..]));
                assert_eq!(band_metadata.storage_type(), StorageType::InDb);
                let band_data = band.data();
                assert_eq!(band_data.len(), tile_width * tile_height);
                if let Some(pixel) = expected_nodata_pixel {
                    assert_eq!(band_data[pixel], u8::MAX);
                }
            }
        }
    }

    #[test]
    fn test_raster_arrays_equal() {
        let raster_array1 = generate_test_rasters(3, None).unwrap();
        let raster_struct_array1 = RasterStructArray::new(&raster_array1);
        assert_raster_arrays_equal(&raster_struct_array1, &raster_struct_array1);
    }

    #[test]
    #[should_panic = "Raster array lengths do not match"]
    fn test_raster_arrays_not_equal() {
        let raster_array1 = generate_test_rasters(3, None).unwrap();
        let raster_struct_array1 = RasterStructArray::new(&raster_array1);

        let raster_array2 = generate_test_rasters(4, None).unwrap();
        let raster_struct_array2 = RasterStructArray::new(&raster_array2);
        assert_raster_arrays_equal(&raster_struct_array1, &raster_struct_array2);
    }

    #[test]
    fn test_raster_equal() {
        let raster_array1 =
            generate_tiled_rasters((256, 256), (1, 1), BandDataType::UInt8, Some(43)).unwrap();
        let raster1 = RasterStructArray::new(&raster_array1).get(0).unwrap();

        assert_raster_equal(&raster1, &raster1);
    }

    #[test]
    #[should_panic = "Band data does not match"]
    fn test_raster_different_band_data() {
        let raster_array1 =
            generate_tiled_rasters((128, 128), (1, 1), BandDataType::UInt8, Some(43)).unwrap();
        let raster_array2 =
            generate_tiled_rasters((128, 128), (1, 1), BandDataType::UInt8, Some(47)).unwrap();

        let raster1 = RasterStructArray::new(&raster_array1).get(0).unwrap();
        let raster2 = RasterStructArray::new(&raster_array2).get(0).unwrap();
        assert_raster_equal(&raster1, &raster2);
    }

    #[test]
    #[should_panic = "Raster upper left x does not match"]
    fn test_raster_different_metadata() {
        let raster_array =
            generate_tiled_rasters((128, 128), (2, 1), BandDataType::UInt8, Some(43)).unwrap();
        let raster1 = RasterStructArray::new(&raster_array).get(0).unwrap();
        let raster2 = RasterStructArray::new(&raster_array).get(1).unwrap();
        assert_raster_equal(&raster1, &raster2);
    }
}