use crate::crs::CRS;
use crate::error::{Error, Result};
use crate::raster::{GeoTransform, RasterElement};
use ndarray::{Array2, ArrayView2, ArrayViewMut2};
/// A two-dimensional grid of cell values together with its georeferencing
/// metadata.
///
/// Cells are addressed as `(row, col)`. The element type must implement
/// [`RasterElement`].
#[derive(Debug, Clone)]
pub struct Raster<T: RasterElement> {
/// Cell values, stored as a 2-D array indexed by `(row, col)`.
data: Array2<T>,
/// Geotransform used for pixel <-> geographic coordinate conversion,
/// cell size and bounds.
transform: GeoTransform,
/// Coordinate reference system, or `None` when none has been assigned.
crs: Option<CRS>,
/// Optional nodata marker; it is handed to `RasterElement::is_nodata`
/// when testing whether a cell value counts as missing.
nodata: Option<T>,
}
impl<T: RasterElement> Raster<T> {
pub fn new(rows: usize, cols: usize) -> Self {
Self {
data: Array2::zeros((rows, cols)),
transform: GeoTransform::default(),
crs: None,
nodata: None,
}
}
pub fn filled(rows: usize, cols: usize, value: T) -> Self {
Self {
data: Array2::from_elem((rows, cols), value),
transform: GeoTransform::default(),
crs: None,
nodata: None,
}
}
pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> Result<Self> {
if data.len() != rows * cols {
return Err(Error::InvalidDimensions {
width: cols,
height: rows,
});
}
let array =
Array2::from_shape_vec((rows, cols), data).map_err(|e| Error::Other(e.to_string()))?;
Ok(Self {
data: array,
transform: GeoTransform::default(),
crs: None,
nodata: None,
})
}
pub fn from_array(data: Array2<T>) -> Self {
Self {
data,
transform: GeoTransform::default(),
crs: None,
nodata: None,
}
}
pub fn with_same_meta<U: RasterElement>(&self, rows: usize, cols: usize) -> Raster<U> {
Raster {
data: Array2::zeros((rows, cols)),
transform: self.transform,
crs: self.crs.clone(),
nodata: None,
}
}
pub fn like(&self, fill_value: T) -> Self {
Self {
data: Array2::from_elem(self.data.dim(), fill_value),
transform: self.transform,
crs: self.crs.clone(),
nodata: self.nodata,
}
}
pub fn rows(&self) -> usize {
self.data.nrows()
}
pub fn cols(&self) -> usize {
self.data.ncols()
}
pub fn shape(&self) -> (usize, usize) {
self.data.dim()
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn get(&self, row: usize, col: usize) -> Result<T> {
self.data
.get((row, col))
.copied()
.ok_or(Error::IndexOutOfBounds {
row,
col,
rows: self.rows(),
cols: self.cols(),
})
}
pub unsafe fn get_unchecked(&self, row: usize, col: usize) -> T {
unsafe { *self.data.uget((row, col)) }
}
pub fn set(&mut self, row: usize, col: usize, value: T) -> Result<()> {
if row >= self.rows() || col >= self.cols() {
return Err(Error::IndexOutOfBounds {
row,
col,
rows: self.rows(),
cols: self.cols(),
});
}
self.data[(row, col)] = value;
Ok(())
}
pub unsafe fn set_unchecked(&mut self, row: usize, col: usize, value: T) {
unsafe {
*self.data.uget_mut((row, col)) = value;
}
}
pub fn view(&self) -> ArrayView2<'_, T> {
self.data.view()
}
pub fn view_mut(&mut self) -> ArrayViewMut2<'_, T> {
self.data.view_mut()
}
pub fn data(&self) -> &Array2<T> {
&self.data
}
pub fn data_mut(&mut self) -> &mut Array2<T> {
&mut self.data
}
pub fn into_array(self) -> Array2<T> {
self.data
}
pub fn row(&self, row: usize) -> Result<ndarray::ArrayView1<'_, T>> {
if row >= self.rows() {
return Err(Error::IndexOutOfBounds {
row,
col: 0,
rows: self.rows(),
cols: self.cols(),
});
}
Ok(self.data.row(row))
}
pub fn transform(&self) -> &GeoTransform {
&self.transform
}
pub fn set_transform(&mut self, transform: GeoTransform) {
self.transform = transform;
}
pub fn crs(&self) -> Option<&CRS> {
self.crs.as_ref()
}
pub fn set_crs(&mut self, crs: Option<CRS>) {
self.crs = crs;
}
pub fn nodata(&self) -> Option<T> {
self.nodata
}
pub fn set_nodata(&mut self, nodata: Option<T>) {
self.nodata = nodata;
}
pub fn cell_size(&self) -> f64 {
self.transform.cell_size()
}
pub fn bounds(&self) -> (f64, f64, f64, f64) {
self.transform.bounds(self.cols(), self.rows())
}
pub fn pixel_to_geo(&self, col: usize, row: usize) -> (f64, f64) {
self.transform.pixel_to_geo(col, row)
}
pub fn geo_to_pixel(&self, x: f64, y: f64) -> (f64, f64) {
self.transform.geo_to_pixel(x, y)
}
pub fn is_nodata(&self, value: T) -> bool {
value.is_nodata(self.nodata)
}
pub fn is_nodata_at(&self, row: usize, col: usize) -> Result<bool> {
let value = self.get(row, col)?;
Ok(self.is_nodata(value))
}
pub fn statistics(&self) -> RasterStatistics<T>
where
T: PartialOrd,
{
let mut min = None;
let mut max = None;
let mut sum: f64 = 0.0;
let mut count: usize = 0;
for &value in self.data.iter() {
if self.is_nodata(value) {
continue;
}
if min.is_none() || value < min.unwrap() {
min = Some(value);
}
if max.is_none() || value > max.unwrap() {
max = Some(value);
}
if let Some(v) = value.to_f64() {
sum += v;
count += 1;
}
}
let mean = if count > 0 {
Some(sum / count as f64)
} else {
None
};
RasterStatistics {
min,
max,
mean,
valid_count: count,
nodata_count: self.len() - count,
}
}
}
/// Summary statistics produced by `Raster::statistics`.
#[derive(Debug, Clone)]
pub struct RasterStatistics<T> {
/// Smallest non-nodata value, or `None` when no such value was seen.
pub min: Option<T>,
/// Largest non-nodata value, or `None` when no such value was seen.
pub max: Option<T>,
/// Arithmetic mean (as `f64`) of the values that contributed to the sum,
/// or `None` when there were none.
pub mean: Option<f64>,
/// Number of cells counted as valid data.
pub valid_count: usize,
/// Total number of cells minus `valid_count`.
pub nodata_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built raster reports the dimensions it was created with.
    #[test]
    fn test_raster_creation() {
        let raster = Raster::<f32>::new(100, 200);
        let (rows, cols) = (raster.rows(), raster.cols());
        assert_eq!(rows, 100);
        assert_eq!(cols, 200);
        assert_eq!(raster.shape(), (100, 200));
    }

    /// A value written with `set` is read back unchanged by `get`.
    #[test]
    fn test_raster_access() {
        let mut raster = Raster::<f32>::new(10, 10);
        raster.set(5, 5, 42.0).unwrap();
        let stored = raster.get(5, 5).unwrap();
        assert_eq!(stored, 42.0);
    }

    /// Statistics over a 10x10 raster holding 0..=99 in row-major order.
    #[test]
    fn test_raster_statistics() {
        let mut raster = Raster::<f32>::new(10, 10);
        // Walk the cells by flat index; cell (r, c) receives r * 10 + c.
        for cell in 0..100_usize {
            raster.set(cell / 10, cell % 10, cell as f32).unwrap();
        }
        let RasterStatistics {
            min,
            max,
            valid_count,
            ..
        } = raster.statistics();
        assert_eq!(min, Some(0.0));
        assert_eq!(max, Some(99.0));
        assert_eq!(valid_count, 100);
    }
}