#![allow(clippy::doc_markdown)]
use std::io::Cursor;
use std::sync::Arc;
use bytes::Bytes;
use cudarc::driver::{CudaContext, CudaStream};
use dlpark::SafeManagedTensorVersioned;
use dlpark::ffi::{DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION, Device};
use numpy::{PyArray1, ToPyArray};
use pyo3::exceptions::{PyBufferError, PyNotImplementedError, PyValueError, PyWarning};
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::define_stub_info_gatherer;
use pyo3_stub_gen_derive::{gen_stub_pyclass, gen_stub_pymethods};
use crate::io::nvtiff::CudaCogReader;
use crate::python::adapters::path_to_stream;
use crate::traits::Transform;
/// Python-facing wrapper around the native [`CudaCogReader`], exported to Python
/// under the name `CudaCogReader`.
///
/// Declared `unsendable`, so PyO3 rejects access from any thread other than the one
/// that created the object (presumably because the native reader / CUDA handles are
/// not `Send` — NOTE(review): confirm).
#[gen_stub_pyclass]
#[pyclass(unsendable)]
#[pyo3(name = "CudaCogReader")]
pub(crate) struct PyCudaCogReader {
    /// Native reader built from the file bytes; performs the DLPack export and the
    /// coordinate computation.
    inner: CudaCogReader,
    /// CUDA device chosen at construction. Reported by `__dlpack_device__` and used by
    /// `__dlpack__` when the consumer does not request a `dl_device`.
    device: Device,
}
#[gen_stub_pymethods]
#[pymethods]
impl PyCudaCogReader {
    /// Open the COG at `path` for decoding on CUDA device `device_id`.
    ///
    /// The file content is obtained as an in-memory `Bytes` buffer through
    /// `path_to_stream` and parsed by the native `CudaCogReader`. `device_id` is only
    /// recorded here; an invalid id is reported later, when `__dlpack__` creates the
    /// CUDA context.
    ///
    /// Raises
    /// ------
    /// ValueError
    ///     If `CudaCogReader::new` rejects the bytes.
    #[new]
    #[pyo3(signature = (path, device_id))]
    fn new(path: &str, device_id: usize) -> PyResult<Self> {
        let stream: Cursor<Bytes> = path_to_stream(path)?;
        let bytes: Bytes = stream.into_inner();
        let cog =
            CudaCogReader::new(&bytes).map_err(|err| PyValueError::new_err(err.to_string()))?;
        Ok(Self {
            inner: cog,
            device: Device::cuda(device_id),
        })
    }

    /// Export the decoded data as a versioned DLPack capsule (array API `__dlpack__`).
    ///
    /// Parameters
    /// ----------
    /// stream
    ///     CUDA stream the consumer will use. `None` or `1` selects the legacy default
    ///     stream, `2` the per-thread default stream. `0` is rejected because the array
    ///     API standard disallows it as ambiguous; any other value (`-1`, raw stream
    ///     handles) is not supported yet.
    /// max_version
    ///     Highest DLPack `(major, minor)` version the consumer can read. A major version
    ///     below the producer's is rejected because only versioned capsules are produced.
    ///     `None` is accepted for backward compatibility.
    /// dl_device
    ///     `(device_type, device_id)` to export to; only `device_type == 2` (CUDA) is
    ///     accepted. Defaults to the device given at construction.
    /// copy
    ///     Must be `None`; other values are not implemented yet.
    ///
    /// Raises
    /// ------
    /// BufferError
    ///     If `dl_device` requests a non-CUDA device type.
    /// ValueError
    ///     If `stream == 0`, the CUDA context cannot be created, or the export fails.
    /// NotImplementedError
    ///     For unsupported `stream` values, an incompatible `max_version`, or a
    ///     non-`None` `copy`.
    #[gen_stub(override_return_type(type_repr="types.CapsuleType", imports=("types")))]
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__(
        &self,
        stream: Option<i64>,
        max_version: Option<(u32, u32)>,
        dl_device: Option<(usize, usize)>,
        copy: Option<bool>,
    ) -> PyResult<SafeManagedTensorVersioned> {
        // Validate every argument before touching the driver so that bad input fails
        // fast and never pays for creating a CUDA context.
        let device: Device = match dl_device {
            None => self.device,
            Some((2, device_id)) => Device::cuda(device_id),
            Some((device_type_int, _)) => {
                return Err(PyBufferError::new_err(format!(
                    "Only DLPack device_type 2 (CUDA) is allowed, got {device_type_int}"
                )));
            }
        };
        let requested_stream = parse_dlpack_stream(stream)?;
        check_dlpack_max_version(max_version)?;
        if copy.is_some() {
            return Err(PyNotImplementedError::new_err(
                "`copy!=None` argument is not yet implemented.",
            ));
        }

        let ctx: Arc<CudaContext> = CudaContext::new(usize::try_from(device.device_id)?)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        let cuda_stream: Arc<CudaStream> = match requested_stream {
            RequestedStream::LegacyDefault => ctx.default_stream(),
            RequestedStream::PerThreadDefault => ctx.per_thread_stream(),
        };
        let tensor: SafeManagedTensorVersioned = self
            .inner
            .dlpack(&cuda_stream)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        Ok(tensor)
    }

    /// Return `(device_type, device_id)` of the device chosen at construction, as
    /// required by the array API `__dlpack_device__` protocol.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (self.device.device_type as i32, self.device.device_id)
    }

    /// Return the `(x, y)` coordinate arrays computed by the native reader as two 1-D
    /// float64 NumPy arrays.
    ///
    /// Raises
    /// ------
    /// ValueError
    ///     If the native reader fails to compute the coordinates.
    // The nested tuple of bound NumPy arrays trips clippy's `type_complexity`; a type
    // alias would only hide the Python-visible return shape.
    #[allow(clippy::type_complexity)]
    fn xy_coords<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
        let (x_coords, y_coords) = self
            .inner
            .xy_coords()
            .map_err(|err| PyValueError::new_err(err.frame().to_string()))?;
        Ok((x_coords.to_pyarray(py), y_coords.to_pyarray(py)))
    }
}

/// CUDA stream selected through the `stream` argument of `__dlpack__`.
enum RequestedStream {
    /// Legacy default stream (`stream=None` or `stream=1`).
    LegacyDefault,
    /// Per-thread default stream (`stream=2`).
    PerThreadDefault,
}

/// Map the DLPack `stream` argument of a CUDA producer onto a supported stream.
///
/// Kept outside the `#[pymethods]` impl so it is not exposed to Python.
fn parse_dlpack_stream(stream: Option<i64>) -> PyResult<RequestedStream> {
    match stream {
        None | Some(1) => Ok(RequestedStream::LegacyDefault),
        Some(2) => Ok(RequestedStream::PerThreadDefault),
        // The array API standard disallows 0 for CUDA because it is ambiguous between
        // the legacy and per-thread default streams; report it instead of panicking.
        Some(0) => Err(PyValueError::new_err(
            "stream=0 is disallowed for CUDA; use 1 (legacy default stream) or 2 (per-thread default stream)",
        )),
        Some(other) => Err(PyNotImplementedError::new_err(format!(
            "only legacy default stream (1) or per-thread default stream (2) is supported for now, got {other}"
        ))),
    }
}

/// Reject consumers whose advertised maximum DLPack version cannot read the versioned
/// capsule produced by `__dlpack__`.
///
/// Only the major version decides compatibility; the consumer is expected to check the
/// exact version stored in the capsule itself. `None` (no version stated) is accepted
/// for backward compatibility with callers that omit `max_version`.
fn check_dlpack_max_version(max_version: Option<(u32, u32)>) -> PyResult<()> {
    match max_version {
        Some((major, minor)) if major < DLPACK_MAJOR_VERSION => {
            Err(PyNotImplementedError::new_err(format!(
                "Only supporting DLPack version {DLPACK_MAJOR_VERSION}.{DLPACK_MINOR_VERSION} \
                 (major version {DLPACK_MAJOR_VERSION}), but got max_version {major}.{minor}"
            )))
        }
        _ => Ok(()),
    }
}
// Defines the `stub_info` function that pyo3-stub-gen uses to collect the
// `#[gen_stub_*]`-annotated items of the crate and emit `.pyi` stubs (presumably
// called from a separate stub-generation binary — NOTE(review): confirm).
define_stub_info_gatherer!(stub_info);