use std::io::{Error, Read, Seek};
use dlpark::SafeManagedTensorVersioned;
use dlpark::traits::InferDataType;
use exn::{ResultExt, bail};
use geo::AffineTransform;
use ndarray::{Array, Array1, Array3, ArrayView3, ArrayViewD};
use tiff::decoder::{Decoder, DecodingResult, Limits};
use tiff::tags::Tag;
use tiff::{ColorType, TiffError, TiffFormatError, TiffUnsupportedError};
use crate::traits::Transform;
type TiffResult<T> = exn::Result<T, TiffError>;
/// GeoTIFF reader that wraps a `tiff` [`Decoder`] over a seekable byte stream.
///
/// The `Read + Seek` bound is required by the wrapped `Decoder`.
pub struct CogReader<R>
where
    R: Read + Seek,
{
    /// Underlying TIFF decoder; all image and tag access goes through it.
    decoder: Decoder<R>,
}
impl<R: Read + Seek> CogReader<R> {
pub fn new(stream: R) -> TiffResult<Self> {
let mut decoder = Decoder::new(stream)?;
decoder = decoder.with_limits(Limits::unlimited());
Ok(Self { decoder })
}
#[allow(clippy::missing_errors_doc)]
pub fn dlpack(&mut self) -> TiffResult<SafeManagedTensorVersioned> {
let num_bands: usize = self.num_samples()?;
let (width, height): (u32, u32) = self.decoder.dimensions()?;
let decode_result = self.decoder.read_image()?;
let shape = (num_bands, height as usize, width as usize);
let tensor: SafeManagedTensorVersioned = match decode_result {
DecodingResult::U8(img_data) => shape_vec_to_tensor(shape, img_data)?,
DecodingResult::U16(img_data) => shape_vec_to_tensor(shape, img_data)?,
DecodingResult::U32(img_data) => shape_vec_to_tensor(shape, img_data)?,
DecodingResult::U64(img_data) => shape_vec_to_tensor(shape, img_data)?,
DecodingResult::I8(img_data) => shape_vec_to_tensor(shape, img_data)?,
DecodingResult::I16(img_data) => shape_vec_to_tensor(shape, img_data)?,
DecodingResult::I32(img_data) => shape_vec_to_tensor(shape, img_data)?,
DecodingResult::I64(img_data) => shape_vec_to_tensor(shape, img_data)?,
DecodingResult::F16(img_data) => shape_vec_to_tensor(shape, img_data)?,
DecodingResult::F32(img_data) => shape_vec_to_tensor(shape, img_data)?,
DecodingResult::F64(img_data) => shape_vec_to_tensor(shape, img_data)?,
};
Ok(tensor)
}
fn num_samples(&mut self) -> TiffResult<usize> {
let color_type = self.decoder.colortype()?;
let num_bands: usize = match color_type {
ColorType::Multiband {
bit_depth: _,
num_samples,
} => num_samples as usize,
ColorType::Gray(_) => 1,
ColorType::RGB(_) => 3,
_ => {
bail!(TiffError::UnsupportedError(
TiffUnsupportedError::UnsupportedColorType(color_type),
));
}
};
Ok(num_bands)
}
}
impl<R: Read + Seek> Transform for &mut CogReader<R> {
    type Err = TiffError;

    /// Build the pixel-to-model affine transform from the GeoTIFF georeferencing tags.
    ///
    /// If `ModelTransformationTag` is present, its row-major 4x4 matrix is used directly,
    /// which also covers rotated rasters. Otherwise the transform is derived from
    /// `ModelPixelScaleTag` and the first tie point of `ModelTiepointTag`, assuming a
    /// north-up raster (no rotation terms).
    ///
    /// # Errors
    ///
    /// Returns `TiffFormatError::InvalidTag` if a georeferencing tag holds too few values,
    /// or the decoder's error if the pixel-scale or tie-point tag cannot be read.
    fn transform(self) -> TiffResult<AffineTransform<f64>> {
        if let Ok(matrix) = self.decoder.get_tag_f64_vec(Tag::ModelTransformationTag) {
            // Row-major 4x4 matrix: X = m0*I + m1*J + m2*K + m3, Y = m4*I + m5*J + m6*K + m7.
            let &[m0, m1, _, m3, m4, m5, _, m7, ..] = matrix.as_slice() else {
                bail!(TiffError::FormatError(TiffFormatError::InvalidTag));
            };
            return Ok(AffineTransform::new(m0, m1, m3, m4, m5, m7));
        }

        let pixel_scale: Vec<f64> = self.decoder.get_tag_f64_vec(Tag::ModelPixelScaleTag)?;
        // Slice patterns turn a too-short tag into an error instead of an indexing panic.
        let &[x_scale, y_scale, _z_scale, ..] = pixel_scale.as_slice() else {
            bail!(TiffError::FormatError(TiffFormatError::InvalidTag));
        };

        let tie_points: Vec<f64> = self.decoder.get_tag_f64_vec(Tag::ModelTiepointTag)?;
        // The first tie point maps raster (I, J, K) to model (X, Y, Z); others are ignored.
        let &[raster_i, raster_j, _raster_k, model_x, model_y, _model_z, ..] =
            tie_points.as_slice()
        else {
            bail!(TiffError::FormatError(TiffFormatError::InvalidTag));
        };

        // Shift the tie point back to raster (0, 0). Rows grow downwards while Y grows
        // upwards, hence `+ j * y_scale`.
        let x_origin = model_x - raster_i * x_scale;
        let y_origin = model_y + raster_j * y_scale;
        Ok(AffineTransform::new(
            x_scale, 0.0, x_origin, 0.0, -y_scale, y_origin,
        ))
    }

    /// Pixel-centre coordinates along the x (column) and y (row) axes.
    ///
    /// Each axis has exactly one value per pixel: `offset + res / 2 + res * k`. Only the
    /// diagonal terms (`a`, `e`) and offsets of the transform are used, so any rotation
    /// terms are ignored.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Transform::transform`] and from reading the image dimensions.
    fn xy_coords(self) -> TiffResult<(Array1<f64>, Array1<f64>)> {
        let transform = self.transform()?;
        let (x_res, y_res) = (transform.a(), transform.e());
        let x_first = transform.xoff() + x_res / 2.0;
        let y_first = transform.yoff() + y_res / 2.0;
        let (x_pixels, y_pixels): (u32, u32) = self.decoder.dimensions()?;
        // Generate from the pixel index so each axis has exactly `x_pixels` / `y_pixels`
        // entries. A float `Array::range(start, end, step)` can gain or lose one to rounding.
        let x_coords: Array1<f64> = (0..x_pixels)
            .map(|col| x_first + x_res * f64::from(col))
            .collect();
        let y_coords: Array1<f64> = (0..y_pixels)
            .map(|row| y_first + y_res * f64::from(row))
            .collect();
        Ok((x_coords, y_coords))
    }
}
fn shape_vec_to_tensor<T: InferDataType>(
shape: (usize, usize, usize),
vec: Vec<T>,
) -> TiffResult<SafeManagedTensorVersioned> {
let size: usize = vec.len();
let array_data = Array3::from_shape_vec(shape, vec).or_raise(|| {
TiffError::IoError(Error::other(
format!("failed to convert vector of size {size} to shape {shape:?}").to_string(),
))
})?;
let tensor = SafeManagedTensorVersioned::new(array_data).or_raise(|| {
TiffError::IoError(Error::other(
"failed to convert array to DLPack tensor".to_string(),
))
})?;
Ok(tensor)
}
pub fn read_geotiff<T: InferDataType + Clone, R: Read + Seek>(stream: R) -> TiffResult<Array3<T>> {
let mut reader = CogReader::new(stream)?;
let tensor: SafeManagedTensorVersioned = reader.dlpack()?;
let num_bands: usize = reader.num_samples()?;
let (width, height): (u32, u32) = reader.decoder.dimensions()?;
let view = ArrayViewD::<T>::try_from(&tensor).or_raise(|| {
TiffError::IoError(Error::other(
"failed to convert DLPack tensor into an Array".to_string(),
))
})?;
let shape = (num_bands, height as usize, width as usize);
let array: ArrayView3<T> = view
.into_shape_with_order((num_bands, height as usize, width as usize))
.or_raise(|| {
TiffError::IoError(Error::other(
format!("failed to reshape Array into shape {shape:?}").to_string(),
))
})?;
Ok(array.to_owned())
}
#[cfg(test)]
mod tests {
    use std::io::{Cursor, Seek, SeekFrom};

    use dlpark::SafeManagedTensorVersioned;
    use dlpark::ffi::DataType;
    use dlpark::prelude::TensorView;
    use geo::AffineTransform;
    use ndarray::{Array, Array3, s};
    use object_store::parse_url;
    use tempfile::tempfile;
    use tiff::encoder::{TiffEncoder, colortype};
    use url::Url;

    use crate::io::geotiff::{CogReader, read_geotiff, shape_vec_to_tensor};
    use crate::traits::Transform;

    // NOTE: the `#[tokio::test]` cases download fixtures over HTTP and need network access.

    #[test]
    fn test_read_geotiff() {
        let mut image_data = Vec::new();
        for y in 0..10u8 {
            for x in 0..20u8 {
                let val = y + x;
                image_data.push(f32::from(val));
            }
        }
        let mut file = tempfile().unwrap();
        let mut bigtiff = TiffEncoder::new_big(&mut file).unwrap();
        bigtiff
            .write_image::<colortype::Gray32Float>(20, 10, &image_data)
            .unwrap();
        file.seek(SeekFrom::Start(0)).unwrap();

        let arr: Array3<f32> = read_geotiff(file).unwrap();
        assert_eq!(arr.ndim(), 3);
        assert_eq!(arr.dim(), (1, 10, 20));
        let first_band = arr.slice(s![0, .., ..]);
        assert_eq!(first_band.nrows(), 10);
        assert_eq!(first_band.ncols(), 20);
        assert_eq!(arr.mean(), Some(14.0));
    }

    #[tokio::test]
    async fn test_read_geotiff_multi_band() {
        let cog_url: &str = "https://github.com/locationtech/geotrellis/raw/v3.7.1/raster/data/one-month-tiles-multiband/result.tif";
        let tif_url = Url::parse(cog_url).unwrap();
        let (store, location) = parse_url(&tif_url).unwrap();
        let result = store.get(&location).await.unwrap();
        let bytes = result.bytes().await.unwrap();
        let stream = Cursor::new(bytes);
        let array: Array3<f32> = read_geotiff(stream).unwrap();
        assert_eq!(array.dim(), (2, 512, 512));
        assert_eq!(array.mean(), Some(225.17654));
    }

    #[tokio::test]
    async fn test_read_geotiff_uint16_dtype() {
        let cog_url: &str =
            "https://github.com/OSGeo/gdal/raw/v3.9.2/autotest/gcore/data/uint16.tif";
        let tif_url = Url::parse(cog_url).unwrap();
        let (store, location) = parse_url(&tif_url).unwrap();
        let result = store.get(&location).await.unwrap();
        let bytes = result.bytes().await.unwrap();
        let stream = Cursor::new(bytes);
        let array: Array3<u16> = read_geotiff(stream).unwrap();
        assert_eq!(array.dim(), (1, 20, 20));
        assert_eq!(array.mean(), Some(126));
    }

    #[tokio::test]
    async fn test_read_geotiff_float16_dtype() {
        let cog_url: &str =
            "https://github.com/OSGeo/gdal/raw/v3.11.0/autotest/gcore/data/float16.tif";
        let tif_url = Url::parse(cog_url).unwrap();
        let (store, location) = parse_url(&tif_url).unwrap();
        let result = store.get(&location).await.unwrap();
        let bytes = result.bytes().await.unwrap();
        let stream = Cursor::new(bytes);
        let array: Array3<half::f16> = read_geotiff(stream).unwrap();
        assert_eq!(array.dim(), (1, 20, 20));
        assert_eq!(array.mean(), Some(half::f16::from_f32_const(127.125)));
    }

    #[test]
    fn reshape_error() {
        let result = shape_vec_to_tensor((1, 2, 3), vec![0, 1, 2]);
        // `unwrap_err` would require `SafeManagedTensorVersioned: Debug`. Going through
        // `Result::err` avoids that without `unwrap_err_unchecked`, which is UB on `Ok`.
        let err = result
            .err()
            .expect("3 elements cannot fill a (1, 2, 3) array");
        assert_eq!(
            err.to_string(),
            "failed to convert vector of size 3 to shape (1, 2, 3)"
        );
    }

    #[tokio::test]
    async fn test_cogreader_dlpack() {
        let cog_url: &str = "https://github.com/rasterio/rasterio/raw/1.3.9/tests/data/float32.tif";
        let tif_url = Url::parse(cog_url).unwrap();
        let (store, location) = parse_url(&tif_url).unwrap();
        let result = store.get(&location).await.unwrap();
        let bytes = result.bytes().await.unwrap();
        let stream = Cursor::new(bytes);
        let mut cog = CogReader::new(stream).unwrap();
        let tensor: SafeManagedTensorVersioned = cog.dlpack().unwrap();
        assert_eq!(tensor.shape(), [1, 2, 3]);
        assert_eq!(tensor.data_type(), &DataType::F32);
        let values: Vec<f32> = tensor
            .as_slice_untyped()
            .to_vec()
            .chunks_exact(4)
            .map(TryInto::try_into)
            .map(Result::unwrap)
            .map(f32::from_le_bytes)
            .collect();
        assert_eq!(values, vec![1.41, 1.23, 0.78, 0.32, -0.23, -1.88]);
    }

    #[tokio::test]
    async fn test_cogreader_num_samples() {
        let cog_url: &str = "https://github.com/developmentseed/titiler/raw/refs/tags/0.22.2/src/titiler/mosaic/tests/fixtures/TCI.tif";
        let tif_url = Url::parse(cog_url).unwrap();
        let (store, location) = parse_url(&tif_url).unwrap();
        let result = store.get(&location).await.unwrap();
        let bytes = result.bytes().await.unwrap();
        let stream = Cursor::new(bytes);
        let mut cog = CogReader::new(stream).unwrap();
        assert_eq!(cog.num_samples().unwrap(), 3);
    }

    #[tokio::test]
    async fn cogreader_transform() {
        let cog_url: &str =
            "https://github.com/cogeotiff/rio-tiler/raw/8.0.5/tests/fixtures/cog_nodata_nan.tif";
        let tif_url = Url::parse(cog_url).unwrap();
        let (store, location) = parse_url(&tif_url).unwrap();
        let result = store.get(&location).await.unwrap();
        let bytes = result.bytes().await.unwrap();
        let stream = Cursor::new(bytes);
        let mut cog = CogReader::new(stream).unwrap();
        let transform = cog.transform().unwrap();
        assert_eq!(
            transform,
            AffineTransform::new(200.0, 0.0, 499_980.0, 0.0, -200.0, 5_300_040.0)
        );
        let (x_coords, y_coords) = cog.xy_coords().unwrap();
        assert_eq!(x_coords, Array::linspace(500_080., 609_680., 549));
        assert_eq!(y_coords, Array::linspace(5_299_940.0, 5_190_340.0, 549));
    }
}