use std::ffi::c_void;
use std::sync::Arc;
use bytes::Bytes;
use cudarc::driver::{CudaSlice, CudaStream, CudaView};
use dlpark::SafeManagedTensorVersioned;
use dlpark::ffi::{DataType, DataTypeCode};
use dlpark::traits::{InferDataType, TensorView};
use exn::{OptionExt, ResultExt};
use geo::AffineTransform;
use ndarray::{Array, Array1};
use nvtiff_sys::result::{NvTiffError, NvTiffStatusError};
use nvtiff_sys::{
NvTiffResultCheck, nvtiffDecodeCheckSupported, nvtiffDecodeImage, nvtiffDecodeParams,
nvtiffDecodeParamsCreate, nvtiffDecoder, nvtiffDecoderCreateSimple, nvtiffImageInfo,
nvtiffSampleFormat, nvtiffStatus, nvtiffStream, nvtiffStreamCreate, nvtiffStreamGetImageInfo,
nvtiffStreamGetNumImages, nvtiffStreamGetTagValue, nvtiffStreamParse, nvtiffTag,
};
use crate::traits::Transform;
/// Result type used throughout this module: an `exn` result whose error is an [`NvTiffError`].
type NvTiffResult<T> = exn::Result<T, NvTiffError>;
/// Reader that parses a TIFF/COG held in host memory with nvTIFF and can decode
/// its first image (index 0) into CUDA device memory (see `CudaCogReader::dlpack`).
///
/// NOTE(review): no `Drop` impl is visible for this type, so the raw nvTIFF
/// stream handle below is presumably never passed to `nvtiffStreamDestroy` — confirm.
pub struct CudaCogReader {
/// Handle to the nvTIFF stream created and parsed in `CudaCogReader::new`.
tiff_stream: *mut nvtiffStream,
/// Metadata (size, sample format, bit depth, ...) of image index 0 of `tiff_stream`.
image_info: nvtiffImageInfo,
}
impl CudaCogReader {
pub fn new(byte_stream: &Bytes) -> NvTiffResult<Self> {
let mut host_stream = std::mem::MaybeUninit::uninit();
let mut tiff_stream: *mut nvtiffStream = host_stream.as_mut_ptr();
let status_cpustream: nvtiffStatus::Type =
unsafe { nvtiffStreamCreate(&raw mut tiff_stream) };
status_cpustream.result()?;
let status_parse: u32 =
unsafe { nvtiffStreamParse(byte_stream.as_ptr(), byte_stream.len(), tiff_stream) };
status_parse.result()?;
let mut num_images: u32 = 0;
let status_numimages: u32 =
unsafe { nvtiffStreamGetNumImages(tiff_stream, &raw mut num_images) };
status_numimages.result()?;
let mut image_info = nvtiffImageInfo::default();
let status_imageinfo: u32 = unsafe {
nvtiffStreamGetImageInfo(
tiff_stream,
0, &raw mut image_info,
)
};
status_imageinfo.result()?;
Ok(Self {
tiff_stream,
image_info,
})
}
pub fn dlpack(&self, stream: &Arc<CudaStream>) -> NvTiffResult<SafeManagedTensorVersioned> {
let cuda_stream: *mut nvtiff_sys::CUstream_st = stream.cu_stream().cast::<_>();
let mut decoder_handle = std::mem::MaybeUninit::zeroed();
let mut nvtiff_decoder: *mut nvtiffDecoder = decoder_handle.as_mut_ptr();
let status_decoder: u32 =
unsafe { nvtiffDecoderCreateSimple(&raw mut nvtiff_decoder, cuda_stream) };
status_decoder.result()?;
let sample_format: u32 = self.image_info.sample_format[0];
let dtype_code: DataTypeCode = match sample_format {
nvtiffSampleFormat::NVTIFF_SAMPLEFORMAT_UINT => DataTypeCode::UInt,
nvtiffSampleFormat::NVTIFF_SAMPLEFORMAT_INT => DataTypeCode::Int,
nvtiffSampleFormat::NVTIFF_SAMPLEFORMAT_IEEEFP => DataTypeCode::Float,
_ => unimplemented!(
"non uint/int/float dtypes (e.g. complex int/float) not implemented yet"
),
};
let bits: u16 = self.image_info.bits_per_pixel / self.image_info.samples_per_pixel;
let dtype: DataType = DataType {
code: dtype_code,
bits: u8::try_from(bits)
.or_raise(|| NvTiffError::StatusError(NvTiffStatusError::TiffNotSupported))?,
lanes: 1,
};
let bytes_per_pixel: usize = self.image_info.bits_per_pixel as usize / 8;
let num_bytes: usize = self.image_info.image_width as usize * self.image_info.image_height as usize * bytes_per_pixel; let cuslice: CudaSlice<u8> = stream.alloc_zeros::<u8>(num_bytes).unwrap_or_else(|err| {
panic!("Failed to allocate {num_bytes} bytes on CUDA device: {err}")
});
let mut params = std::mem::MaybeUninit::zeroed();
let mut decode_params: *mut nvtiffDecodeParams = params.as_mut_ptr();
let status_param: u32 = unsafe { nvtiffDecodeParamsCreate(&raw mut decode_params) };
status_param.result()?;
let status_check: u32 = unsafe {
nvtiffDecodeCheckSupported(
self.tiff_stream, nvtiff_decoder,
decode_params,
0, )
};
status_check.result()?;
let len_elem: usize = num_bytes / (dtype.bits as usize / 8);
let tensor: SafeManagedTensorVersioned = match dtype {
DataType::U8 => SafeManagedTensorVersioned::new(cuslice)
.or_raise(|| NvTiffError::StatusError(NvTiffStatusError::AllocatorFailure))?,
DataType::U16 => cudaslice_to_tensor::<u16>(cuslice, len_elem)?,
DataType::U32 => cudaslice_to_tensor::<u32>(cuslice, len_elem)?,
DataType::U64 => cudaslice_to_tensor::<u64>(cuslice, len_elem)?,
DataType::I8 => cudaslice_to_tensor::<i8>(cuslice, len_elem)?,
DataType::I16 => cudaslice_to_tensor::<i16>(cuslice, len_elem)?,
DataType::I32 => cudaslice_to_tensor::<i32>(cuslice, len_elem)?,
DataType::I64 => cudaslice_to_tensor::<i64>(cuslice, len_elem)?,
DataType::F32 => cudaslice_to_tensor::<f32>(cuslice, len_elem)?,
DataType::F64 => cudaslice_to_tensor::<f64>(cuslice, len_elem)?,
dtype => {
unimplemented!("Converting {dtype:?} into DLPack not supported yet.")
}
};
let status_decode: u32 = unsafe {
nvtiffDecodeImage(
self.tiff_stream,
nvtiff_decoder,
decode_params,
0, tensor.data_ptr(),
cuda_stream,
)
};
status_decode.result()?;
Ok(tensor)
}
}
impl Transform for &CudaCogReader {
type Err = NvTiffError;
fn transform(self) -> NvTiffResult<AffineTransform<f64>> {
let transformation = &mut [f64::NAN; 16];
let status_transformationinfo: u32 = unsafe {
nvtiffStreamGetTagValue(
self.tiff_stream,
0, nvtiffTag::NVTIFF_TAG_MODEL_TRANSFORMATION,
transformation.as_mut_ptr().cast::<c_void>(),
16,
)
};
let (x_rotation, y_rotation): (f64, f64) = match status_transformationinfo.result() {
Ok(()) => {
unimplemented!("ModelTransformationTag and/or non-zero rotation not supported yet")
}
Err(_) => (0.0, 0.0),
};
let pixel_scale = &mut [f64::NAN; 3];
let status_pixelscaleinfo: u32 = unsafe {
nvtiffStreamGetTagValue(
self.tiff_stream,
0, nvtiffTag::NVTIFF_TAG_MODEL_PIXEL_SCALE,
pixel_scale.as_mut_ptr().cast::<c_void>(),
3,
)
};
status_pixelscaleinfo.result()?;
let [x_scale, y_scale, _z_scale] = *pixel_scale;
let tie_points = &mut [f64::NAN; 6];
let status_tiepointinfo: u32 = unsafe {
nvtiffStreamGetTagValue(
self.tiff_stream,
0, nvtiffTag::NVTIFF_TAG_MODEL_TIE_POINT,
tie_points.as_mut_ptr().cast::<c_void>(),
6,
)
};
status_tiepointinfo.result()?;
let [_i, _j, _k, x_origin, y_origin, _z_origin] = *tie_points;
let transform = AffineTransform::new(
x_scale, x_rotation, x_origin, y_rotation, -y_scale, y_origin,
);
Ok(transform)
}
fn xy_coords(self) -> NvTiffResult<(Array1<f64>, Array1<f64>)> {
let transform: AffineTransform = self.transform()?;
let x_res: &f64 = &transform.a();
let y_res: &f64 = &transform.e();
let x_origin: &f64 = &(transform.xoff() + x_res / 2.0);
let y_origin: &f64 = &(transform.yoff() + y_res / 2.0);
let x_pixels: u32 = self.image_info.image_width;
let y_pixels: u32 = self.image_info.image_height;
let x_end: f64 = x_origin + x_res * f64::from(x_pixels);
let y_end: f64 = y_origin + y_res * f64::from(y_pixels);
let x_coords = Array::range(x_origin.to_owned(), x_end, x_res.to_owned());
let y_coords = Array::range(y_origin.to_owned(), y_end, y_res.to_owned());
Ok((x_coords, y_coords))
}
}
/// Reinterprets a device byte buffer as `len_elem` elements of `T` and wraps
/// that typed view in a DLPack tensor.
///
/// On success the allocation behind `cuslice` is leaked with `CudaSlice::leak`.
/// This keeps the tensor's data pointer valid after `cuslice` goes out of scope.
///
/// NOTE(review): nothing here frees the leaked allocation later. Confirm whether
/// the tensor's deleter releases it; if not, every tensor built here leaks
/// device memory.
///
/// # Errors
///
/// * `BadTiff` if `CudaSlice::transmute` returns `None`, presumably because
///   `len_elem` elements of `T` do not fit in `cuslice`.
/// * `AllocatorFailure` if the DLPack tensor cannot be created.
fn cudaslice_to_tensor<T: InferDataType>(
    cuslice: CudaSlice<u8>,
    len_elem: usize,
) -> NvTiffResult<SafeManagedTensorVersioned> {
    let typed_view: CudaView<'_, T> = {
        // SAFETY: the caller picks `T` to match the sample type encoded in the bytes.
        let reinterpreted = unsafe { cuslice.transmute::<T>(len_elem) };
        reinterpreted.ok_or_raise(|| NvTiffError::StatusError(NvTiffStatusError::BadTiff))?
    };

    let managed = SafeManagedTensorVersioned::new(typed_view)
        .or_raise(|| NvTiffError::StatusError(NvTiffStatusError::AllocatorFailure))?;

    // Release ownership so dropping `cuslice` does not free memory the tensor still points at.
    let _ = cuslice.leak();
    Ok(managed)
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use cudarc::driver::{CudaContext, CudaSlice, CudaStream};
    use dlpark::SafeManagedTensorVersioned;
    use dlpark::ffi::DataType;
    use dlpark::prelude::TensorView;
    use geo::AffineTransform;
    use ndarray::Array;
    use object_store::parse_url;
    use rstest::rstest;
    use url::Url;

    use crate::io::nvtiff::CudaCogReader;
    use crate::traits::Transform;

    /// Downloads the whole object at `url` into memory, panicking on any failure.
    async fn fetch_bytes(url: &str) -> bytes::Bytes {
        let tif_url = Url::parse(url).unwrap();
        let (store, location) = parse_url(&tif_url).unwrap();
        let result = store.get(&location).await.unwrap();
        result.bytes().await.unwrap()
    }

    #[tokio::test]
    async fn cudacogreader_dlpack() {
        let bytes = fetch_bytes(
            "https://github.com/rasterio/rasterio/raw/refs/tags/1.4.3/tests/data/float32.tif",
        )
        .await;
        let ctx: Arc<CudaContext> = CudaContext::new(0).unwrap();
        let cuda_stream: Arc<CudaStream> = ctx.per_thread_stream();

        let cog: CudaCogReader = CudaCogReader::new(&bytes).unwrap();
        let tensor: SafeManagedTensorVersioned = cog.dlpack(&cuda_stream).unwrap();
        assert_eq!(tensor.data_type(), &DataType::F32);
        assert_eq!(tensor.shape(), [6]);

        let mut image_out_h: Vec<f32> = vec![0.0; tensor.num_elements()];
        // SAFETY: the pointer and element count describe the f32 buffer decoded
        // by `dlpack`. `cuslice` presumably takes ownership of that (leaked)
        // allocation and frees it on drop — confirm against cudarc docs.
        let cuslice: CudaSlice<f32> = unsafe {
            cuda_stream.upgrade_device_ptr(tensor.data_ptr() as u64, tensor.num_elements())
        };
        // Copy straight from the decoded buffer; cloning `cuslice` first would
        // allocate and fill a second device buffer for nothing.
        cuda_stream.memcpy_dtoh(&cuslice, &mut image_out_h).unwrap();
        assert_eq!(image_out_h, vec![1.41, 1.23, 0.78, 0.32, -0.23, -1.88]);
    }

    #[rstest]
    #[case::u8("byte.tif", DataType::U8)]
    #[case::u16("uint16.tif", DataType::U16)]
    #[case::u32("uint32.tif", DataType::U32)]
    #[case::i16("int16.tif", DataType::I16)]
    #[case::i32("int32.tif", DataType::I32)]
    #[tokio::test]
    async fn cudacogreader_dlpack_uint_int_dtypes(#[case] filename: &str, #[case] dtype: DataType) {
        let cog_url: String = format!(
            "https://github.com/OSGeo/gdal/raw/v3.12.0beta1/autotest/gcore/data/{filename}",
        );
        let bytes = fetch_bytes(cog_url.as_str()).await;
        let ctx: Arc<CudaContext> = CudaContext::new(0).unwrap();
        let cuda_stream: Arc<CudaStream> = ctx.per_thread_stream();

        let cog: CudaCogReader = CudaCogReader::new(&bytes).unwrap();
        let tensor: SafeManagedTensorVersioned = cog.dlpack(&cuda_stream).unwrap();
        assert_eq!(tensor.data_type(), &dtype);
        assert_eq!(tensor.shape(), [400]);
    }

    #[tokio::test]
    async fn unimplemented_error() {
        let bytes = fetch_bytes(
            "https://github.com/image-rs/image-tiff/raw/v0.11.2/tests/images/tiled-cmyk-i8.tif",
        )
        .await;
        let ctx: Arc<CudaContext> = CudaContext::new(0).unwrap();
        let cuda_stream: Arc<CudaStream> = ctx.per_thread_stream();

        let cog = CudaCogReader::new(&bytes).unwrap();
        let result = cog.dlpack(&cuda_stream);
        assert_eq!(
            result.err().unwrap().to_string(),
            "Status error: Attempting to decode a TIFF stream that is not supported by the nvTIFF library."
        );
    }

    #[tokio::test]
    async fn cudacogreader_transform() {
        let bytes = fetch_bytes(
            "https://github.com/cogeotiff/rio-tiler/raw/8.0.5/tests/fixtures/cog_nodata_nan.tif",
        )
        .await;

        let cog: CudaCogReader = CudaCogReader::new(&bytes).unwrap();
        let transform = cog.transform().unwrap();
        assert_eq!(
            transform,
            AffineTransform::new(200.0, 0.0, 499_980.0, 0.0, -200.0, 5_300_040.0)
        );

        let (x_coords, y_coords) = cog.xy_coords().unwrap();
        assert_eq!(x_coords, Array::linspace(500_080., 609_680., 549));
        assert_eq!(y_coords, Array::linspace(5_299_940.0, 5_190_340.0, 549));
    }
}