#![allow(clippy::doc_markdown)]
use ndarray::{ArrayD, IxDyn};
use numpy::{IntoPyArray, PyArrayDyn, PyArrayMethods, PyReadonlyArrayDyn};
use pyo3::exceptions::{
PyFileNotFoundError, PyIOError, PyMemoryError, PyStopIteration, PyValueError,
};
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use crate::error::Error as MedrsError;
use crate::nifti::image::ArrayData;
/// Translate a library error into the Python exception a caller should see.
///
/// The message is always prefixed with `context` (`"<context>: <detail>"`).
/// I/O failures become `OSError` (`PyIOError`), allocation failures become
/// `MemoryError`, and every other variant becomes `ValueError`.
fn to_py_err(e: MedrsError, context: &str) -> PyErr {
    // Variants with a dedicated exception class return early; the rest only
    // differ in how their detail text is rendered.
    let detail = match &e {
        MedrsError::Io(io_err) => {
            return PyIOError::new_err(format!("{}: {}", context, io_err));
        }
        MedrsError::MemoryAllocation(msg) => {
            return PyMemoryError::new_err(format!("{}: {}", context, msg));
        }
        MedrsError::InvalidDimensions(msg)
        | MedrsError::InvalidAffine(msg)
        | MedrsError::InvalidCropRegion(msg)
        | MedrsError::ShapeMismatch(msg)
        | MedrsError::InvalidFileFormat(msg)
        | MedrsError::InvalidOrientation(msg)
        | MedrsError::NonContiguousArray(msg)
        | MedrsError::Configuration(msg)
        | MedrsError::Decompression(msg)
        | MedrsError::Exhausted(msg) => msg.to_string(),
        MedrsError::InvalidMagic(magic) => format!("invalid NIfTI magic bytes {:?}", magic),
        MedrsError::UnsupportedDataType(code) => format!("unsupported data type code {}", code),
        MedrsError::DataTypeMismatch { expected, got } => {
            format!("data type mismatch (expected {}, got {})", expected, got)
        }
        MedrsError::TransformError { operation, reason } => {
            format!("{} failed: {}", operation, reason)
        }
    };
    PyValueError::new_err(format!("{}: {}", context, detail))
}
use crate::nifti::DataType;
use crate::nifti::{self, NiftiImage as RustNiftiImage};
use crate::pipeline::TransformPipeline as RustTransformPipeline;
use crate::transforms::crop::{
compute_center_crop_regions, compute_label_aware_crop_regions,
compute_random_spatial_crop_regions,
};
use crate::transforms::{self, Interpolation, Orientation};
/// Reject a 3-D shape that has a zero-length axis.
///
/// Raises `ValueError` naming the first offending axis index.
fn validate_shape(shape: &[usize; 3], name: &str) -> PyResult<()> {
    match shape.iter().position(|&extent| extent == 0) {
        Some(axis) => Err(PyValueError::new_err(format!(
            "{} dimension {} must be positive (got 0)",
            name, axis
        ))),
        None => Ok(()),
    }
}
/// Require every voxel-spacing component to be strictly positive and finite.
///
/// Axes are checked in order. For each axis the "positive" test runs before the
/// "finite" test, so `-inf` reports as non-positive and `NaN`/`+inf` as non-finite.
fn validate_spacing(spacing: &[f32; 3], name: &str) -> PyResult<()> {
    spacing.iter().enumerate().try_for_each(|(axis, &value)| {
        if value <= 0.0 {
            Err(PyValueError::new_err(format!(
                "{} dimension {} must be positive (got {})",
                name, axis, value
            )))
        } else if !value.is_finite() {
            Err(PyValueError::new_err(format!(
                "{} dimension {} must be finite (got {})",
                name, axis, value
            )))
        } else {
            Ok(())
        }
    })
}
/// Convert a Python sequence of extents into a validated `[usize; 3]`.
///
/// Raises `ValueError` when `values` does not have exactly three entries or
/// when any entry is zero (see `validate_shape`).
fn parse_shape3(values: &[usize], name: &str) -> PyResult<[usize; 3]> {
    match *values {
        [d0, d1, d2] => {
            let shape = [d0, d1, d2];
            validate_shape(&shape, name)?;
            Ok(shape)
        }
        _ => Err(PyValueError::new_err(format!(
            "{} must be a 3-element sequence (got {})",
            name,
            values.len()
        ))),
    }
}
/// Validate a user-supplied file path before an I/O operation and return it as a `PathBuf`.
///
/// `operation` is used as the error-message prefix and also selects extra checks:
/// - if it contains `"load"` or `"read"`, `path` must exist (`FileNotFoundError`
///   otherwise) and be a regular file (`ValueError` otherwise);
/// - if it contains `"save"` or `"write"`, the parent directory must exist
///   (`FileNotFoundError` otherwise).
///
/// Empty paths and paths containing NUL bytes are always rejected with `ValueError`.
fn validate_file_path(path: &str, operation: &str) -> PyResult<std::path::PathBuf> {
    if path.is_empty() {
        return Err(PyValueError::new_err(format!(
            "{}: file path cannot be empty",
            operation
        )));
    }
    if path.contains('\0') {
        return Err(PyValueError::new_err(format!(
            "{}: file path cannot contain null bytes",
            operation
        )));
    }
    let path_buf = std::path::PathBuf::from(path);
    let is_read = operation.contains("load") || operation.contains("read");
    let is_write = operation.contains("save") || operation.contains("write");
    if is_read {
        if !path_buf.exists() {
            return Err(PyFileNotFoundError::new_err(format!(
                "{}: file not found: {}",
                operation, path
            )));
        }
        if !path_buf.is_file() {
            return Err(PyValueError::new_err(format!(
                "{}: path is not a file: {}",
                operation, path
            )));
        }
    }
    if is_write {
        // `Path::parent` returns `Some("")` for a bare file name such as "out.nii.gz",
        // and `Path::new("").exists()` is false. An empty parent means the current
        // directory, so skip the check rather than rejecting every single-component path.
        let parent = path_buf
            .parent()
            .filter(|dir| !dir.as_os_str().is_empty());
        if let Some(parent) = parent {
            if !parent.exists() {
                return Err(PyFileNotFoundError::new_err(format!(
                    "{}: parent directory does not exist: {}",
                    operation,
                    parent.display()
                )));
            }
        }
    }
    Ok(path_buf)
}
/// Check that an intensity interval `[min, max]` is usable.
///
/// Both bounds must be finite, and `min` may not exceed `max` (equal bounds are
/// allowed). The first failing condition is reported as a `ValueError`.
fn validate_intensity_range(min: f64, max: f64, param_name: &str) -> PyResult<()> {
    let complaint = if !min.is_finite() {
        Some(format!(
            "{}: min value must be finite (got {})",
            param_name, min
        ))
    } else if !max.is_finite() {
        Some(format!(
            "{}: max value must be finite (got {})",
            param_name, max
        ))
    } else if min > max {
        Some(format!(
            "{}: min ({}) cannot be greater than max ({})",
            param_name, min, max
        ))
    } else {
        None
    };
    complaint.map_or(Ok(()), |msg| Err(PyValueError::new_err(msg)))
}
/// Require `p` to be a finite probability in the closed interval `[0.0, 1.0]`.
///
/// Non-finite values are reported before out-of-range ones.
fn validate_probability(p: f64, param_name: &str) -> PyResult<()> {
    let finite = p.is_finite();
    if finite && (0.0..=1.0).contains(&p) {
        return Ok(());
    }
    let message = if finite {
        format!(
            "{}: probability must be between 0.0 and 1.0 (got {})",
            param_name, p
        )
    } else {
        format!(
            "{}: probability must be finite (got {})",
            param_name, p
        )
    };
    Err(PyValueError::new_err(message))
}
/// Build a `RustNiftiImage` from a borrowed float32 array view.
///
/// The view must have at least 3 axes, none of length 0 and none longer than
/// `u16::MAX`; violations raise `ValueError`. The voxels are copied into a new
/// Fortran-ordered (column-major) array. `affine` defaults to the 4x4 identity.
fn create_nifti_from_numpy_array(
    arr: ndarray::ArrayViewD<'_, f32>,
    affine: Option<[[f32; 4]; 4]>,
) -> PyResult<RustNiftiImage> {
    use ndarray::ShapeBuilder;
    let shape = arr.shape();
    if shape.len() < 3 {
        return Err(PyValueError::new_err(
            "Array must have at least 3 dimensions (D,H,W)",
        ));
    }
    if let Some(axis) = shape.iter().position(|&dim| dim == 0) {
        return Err(PyValueError::new_err(format!(
            "Array dimension {} cannot be 0",
            axis
        )));
    }
    let max_dim = usize::from(u16::MAX);
    if let Some((axis, &dim)) = shape.iter().enumerate().find(|&(_, &d)| d > max_dim) {
        return Err(PyValueError::new_err(format!(
            "Array dimension {} ({}) exceeds maximum NIfTI dimension size ({})",
            axis,
            dim,
            u16::MAX
        )));
    }
    // Copy by logical index rather than by raw memory order. A contiguous view can be
    // C-ordered, F-ordered, axis-permuted (e.g. `arr.transpose(1, 0, 2)`) or
    // negatively strided (e.g. `np.flip`). Reinterpreting its memory-order slice as
    // C or F order would scramble the last two cases; `assign` handles all of them.
    let mut array = ArrayD::<f32>::zeros(IxDyn(shape).f());
    array.assign(&arr);
    let affine = affine.unwrap_or([
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]);
    Ok(RustNiftiImage::from_array(array, affine))
}
/// Python-facing `NiftiImage`: a thin wrapper around the Rust [`RustNiftiImage`].
///
/// Transform methods never modify `self`; each returns a new `NiftiImage`.
/// The only in-place mutation exposed here is the `affine` setter.
#[pyclass(name = "NiftiImage")]
struct PyNiftiImage {
    inner: RustNiftiImage,
}

#[pymethods]
impl PyNiftiImage {
    /// Build an image from a float32 numpy array with at least 3 dimensions.
    ///
    /// `affine` defaults to the 4x4 identity. Invalid arrays raise `ValueError`.
    #[new]
    #[pyo3(signature = (data, affine=None))]
    fn new(data: PyReadonlyArrayDyn<'_, f32>, affine: Option<[[f32; 4]; 4]>) -> PyResult<Self> {
        let inner = create_nifti_from_numpy_array(data.as_array(), affine)?;
        Ok(Self { inner })
    }

    /// Extent of every axis.
    #[getter]
    fn shape(&self) -> Vec<usize> {
        self.inner.shape().to_vec()
    }

    /// Number of axes.
    #[getter]
    fn ndim(&self) -> usize {
        self.inner.ndim()
    }

    /// Name of the stored data type (e.g. `"float32"`).
    #[getter]
    fn dtype(&self) -> &'static str {
        self.inner.dtype().type_name()
    }

    /// Voxel spacing as reported by the underlying image.
    #[getter]
    fn spacing(&self) -> Vec<f32> {
        self.inner.spacing().to_vec()
    }

    /// 4x4 voxel-to-world affine.
    #[getter]
    fn affine(&self) -> [[f32; 4]; 4] {
        self.inner.affine()
    }

    /// Replace the affine in place.
    #[setter]
    fn set_affine(&mut self, affine: [[f32; 4]; 4]) {
        self.inner.set_affine(affine);
    }

    /// Orientation code derived from the current affine.
    #[getter]
    fn orientation(&self) -> String {
        let affine = self.inner.affine();
        transforms::orientation_from_affine(&affine).to_string()
    }

    /// Voxel data as a float32 numpy array (same as `to_numpy()`).
    #[getter]
    fn data<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArrayDyn<f32>>> {
        self.to_numpy(py)
    }

    /// Voxel data as a float32 numpy array.
    ///
    /// Uses the direct f32 view when the image stores f32 data, otherwise converts.
    /// Either way the data is copied into a fresh numpy array.
    fn to_numpy<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArrayDyn<f32>>> {
        match to_numpy_view(py, &self.inner) {
            Some(arr) => Ok(arr),
            None => to_numpy_array(py, &self.inner),
        }
    }

    /// Identical to `to_numpy`; kept for API compatibility.
    ///
    /// Despite the name, the returned array does not alias the image's buffer.
    fn to_numpy_view<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArrayDyn<f32>>> {
        self.to_numpy(py)
    }

    /// Convert the image to a `torch.Tensor`.
    ///
    /// If torch has a dtype matching the stored NIfTI dtype, the tensor keeps that
    /// dtype. Otherwise (unsigned 16/32/64-bit data) the data is converted to float32.
    fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
        let torch = py.import("torch")?;
        let tensor = match torch_dtype(py, self.inner.dtype()) {
            Some(dt) => {
                let np_obj = match to_numpy_view_native(py, &self.inner) {
                    Some(view) => view,
                    None => {
                        let owned = self
                            .inner
                            .owned_data()
                            .map_err(|e| to_py_err(e, "to_torch owned_data"))?;
                        arraydata_to_numpy(py, &owned, self.inner.shape())?
                    }
                };
                torch
                    .getattr("from_numpy")?
                    .call1((np_obj,))?
                    .call_method1("to", (dt,))?
            }
            None => {
                let np_arr = to_numpy_array(py, &self.inner)?;
                torch.getattr("from_numpy")?.call1((np_arr,))?
            }
        };
        Ok(tensor.unbind())
    }

    /// Convert the image to a float32 `jax.numpy` array.
    fn to_jax(&self, py: Python<'_>) -> PyResult<PyObject> {
        let jnp = py.import("jax.numpy")?;
        let np_obj = to_numpy_array(py, &self.inner)?;
        let arr = jnp.getattr("array")?.call1((np_obj,))?;
        Ok(arr.unbind())
    }

    /// Convert to a `torch.Tensor`, move it to `device` (default `"cpu"`), then cast
    /// to `dtype` if one is given.
    #[pyo3(signature = (dtype=None, device=None))]
    fn to_torch_with_dtype_and_device(
        &self,
        py: Python<'_>,
        dtype: Option<PyObject>,
        device: Option<&str>,
    ) -> PyResult<PyObject> {
        let device_str = device.unwrap_or("cpu");
        let torch = py.import("torch")?;
        // Mirror `to_torch`: only hand torch a native-dtype array when torch exposes a
        // matching dtype. Unsigned 16/32/64-bit data falls back to float32 instead of
        // reaching `torch.from_numpy` with a dtype it may reject.
        let native = if torch_dtype(py, self.inner.dtype()).is_some() {
            to_numpy_view_native(py, &self.inner)
        } else {
            None
        };
        let np_obj = match native {
            Some(view) => view,
            None => to_numpy_array(py, &self.inner)?.into_any().unbind(),
        };
        let mut tensor = torch
            .getattr("from_numpy")?
            .call1((np_obj,))?
            .call_method1("to", (device_str,))?;
        if let Some(dt) = dtype {
            tensor = tensor.call_method1("to", (dt,))?;
        }
        Ok(tensor.unbind())
    }

    /// Convert to a `jax.Array`, optionally cast it to `dtype`, and place it on `device`.
    ///
    /// `device` accepts `"cpu"`, `"cuda"`, `"gpu"`, `"cuda:N"` or `"gpu:N"` (default
    /// `"cpu"`). Any other string, including a non-numeric `N`, raises `ValueError`.
    #[pyo3(signature = (dtype=None, device=None))]
    fn to_jax_with_dtype_and_device(
        &self,
        py: Python<'_>,
        dtype: Option<PyObject>,
        device: Option<&str>,
    ) -> PyResult<PyObject> {
        let device_str = device.unwrap_or("cpu");
        let jax = py.import("jax")?;
        let jnp = py.import("jax.numpy")?;
        let devices = jax.getattr("devices")?;
        let device_obj = if device_str == "cpu" {
            devices.call1(("cpu",))?.get_item(0)?
        } else {
            let index = match device_str {
                "cuda" | "gpu" => Some(0_usize),
                other => other
                    .strip_prefix("cuda:")
                    .or_else(|| other.strip_prefix("gpu:"))
                    .and_then(|n| n.parse::<usize>().ok()),
            };
            let index = index.ok_or_else(|| {
                PyValueError::new_err(format!(
                    "Unsupported device: {}. Use 'cpu', 'cuda', 'cuda:N', 'gpu', or 'gpu:N'",
                    device_str
                ))
            })?;
            devices.call1(("gpu",))?.get_item(index)?
        };
        let np_obj = match to_numpy_view_native(py, &self.inner) {
            Some(view) => view,
            None => to_numpy_array(py, &self.inner)?.into_any().unbind(),
        };
        let mut arr = jnp.getattr("array")?.call1((np_obj,))?;
        if let Some(dt) = dtype {
            // Cast through the array's own `astype` so the array, not the dtype,
            // is the value being converted.
            arr = arr.call_method1("astype", (dt,))?;
        }
        let placed = jax.getattr("device_put")?.call1((arr, &device_obj))?;
        Ok(placed.unbind())
    }

    /// Voxel data as a numpy array in the stored dtype.
    ///
    /// float16 and bfloat16 data is widened to float32.
    fn to_numpy_native(&self, py: Python<'_>) -> PyResult<PyObject> {
        arraydata_to_numpy(
            py,
            &self
                .inner
                .owned_data()
                .map_err(|e| to_py_err(e, "to_numpy_native"))?,
            self.inner.shape(),
        )
    }

    /// Alternate constructor; equivalent to `NiftiImage(data, affine)`.
    #[staticmethod]
    #[pyo3(signature = (data, affine=None))]
    fn from_numpy<'py>(
        _py: Python<'py>,
        data: PyReadonlyArrayDyn<'py, f32>,
        affine: Option<[[f32; 4]; 4]>,
    ) -> PyResult<Self> {
        let inner = create_nifti_from_numpy_array(data.as_array(), affine)?;
        Ok(Self { inner })
    }

    /// Write the image to `path` after validating that its parent directory exists.
    fn save(&self, path: &str) -> PyResult<()> {
        let validated_path = validate_file_path(path, "save")?;
        let path_str = validated_path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("path contains invalid UTF-8"))?;
        nifti::save(&self.inner, path_str)
            .map_err(|e| to_py_err(e, &format!("Failed to save {}", path)))
    }

    /// Return a copy converted to `dtype` (case-insensitive name or short alias).
    fn with_dtype(&self, dtype: &str) -> PyResult<Self> {
        let target_dtype = match dtype.to_lowercase().as_str() {
            "float32" | "f32" => nifti::DataType::Float32,
            "float64" | "f64" => nifti::DataType::Float64,
            "float16" | "f16" => nifti::DataType::Float16,
            "bfloat16" | "bf16" => nifti::DataType::BFloat16,
            "int8" | "i8" => nifti::DataType::Int8,
            "uint8" | "u8" => nifti::DataType::UInt8,
            "int16" | "i16" => nifti::DataType::Int16,
            "uint16" | "u16" => nifti::DataType::UInt16,
            "int32" | "i32" => nifti::DataType::Int32,
            "uint32" | "u32" => nifti::DataType::UInt32,
            "int64" | "i64" => nifti::DataType::Int64,
            "uint64" | "u64" => nifti::DataType::UInt64,
            _ => return Err(PyValueError::new_err(format!(
                "Unsupported dtype '{}'. Use: float32, float64, float16, bfloat16, int8, uint8, int16, uint16, int32, uint32, int64, uint64",
                dtype
            ))),
        };
        Ok(Self {
            inner: self
                .inner
                .with_dtype(target_dtype)
                .map_err(|e| to_py_err(e, "with_dtype"))?,
        })
    }

    /// Resample to the given voxel spacing (`"trilinear"`/`"linear"` or `"nearest"`).
    ///
    /// Spacing components must be positive and finite (`ValueError` otherwise).
    #[pyo3(signature = (spacing, method=None))]
    fn resample(&self, spacing: [f32; 3], method: Option<&str>) -> PyResult<Self> {
        validate_spacing(&spacing, "spacing")?;
        let interp = match method.unwrap_or("trilinear") {
            "trilinear" | "linear" => Interpolation::Trilinear,
            "nearest" => Interpolation::Nearest,
            _ => {
                return Err(PyValueError::new_err(
                    "method must be 'trilinear' or 'nearest'",
                ))
            }
        };
        let resampled = transforms::resample_to_spacing(&self.inner, spacing, interp)
            .map_err(|e| PyValueError::new_err(format!("Resampling failed: {}", e)))?;
        Ok(Self { inner: resampled })
    }

    /// Resample to an explicit output grid shape (every axis must be non-zero).
    #[pyo3(signature = (shape, method=None))]
    fn resample_to_shape(&self, shape: [usize; 3], method: Option<&str>) -> PyResult<Self> {
        validate_shape(&shape, "shape")?;
        let interp = match method.unwrap_or("trilinear") {
            "trilinear" | "linear" => Interpolation::Trilinear,
            "nearest" => Interpolation::Nearest,
            _ => {
                return Err(PyValueError::new_err(
                    "method must be 'trilinear' or 'nearest'",
                ))
            }
        };
        let resampled = transforms::resample_to_shape(&self.inner, shape, interp)
            .map_err(|e| PyValueError::new_err(format!("Resampling failed: {}", e)))?;
        Ok(Self { inner: resampled })
    }

    /// Reorient to an orientation code parsed by `Orientation::from_str`.
    fn reorient(&self, orientation: &str) -> PyResult<Self> {
        let target: Orientation = orientation
            .parse()
            .map_err(|e| PyValueError::new_err(format!("{}", e)))?;
        Ok(Self {
            inner: transforms::reorient(&self.inner, target)
                .map_err(|e| PyValueError::new_err(format!("Reorientation failed: {}", e)))?,
        })
    }

    /// Z-score normalization.
    fn z_normalize(&self) -> PyResult<Self> {
        Ok(Self {
            inner: transforms::z_normalization(&self.inner)
                .map_err(|e| to_py_err(e, "z_normalize"))?,
        })
    }

    /// Linearly rescale intensities into `[out_min, out_max]`.
    fn rescale(&self, out_min: f64, out_max: f64) -> PyResult<Self> {
        Ok(Self {
            inner: transforms::rescale_intensity(&self.inner, out_min, out_max)
                .map_err(|e| to_py_err(e, "rescale"))?,
        })
    }

    /// Clamp intensities to `[min, max]`.
    ///
    /// Both bounds must be finite with `min <= max`, the same rule the module-level
    /// `clamp` function enforces.
    fn clamp(&self, min: f64, max: f64) -> PyResult<Self> {
        validate_intensity_range(min, max, "clamp")?;
        Ok(Self {
            inner: transforms::clamp(&self.inner, min, max).map_err(|e| to_py_err(e, "clamp"))?,
        })
    }

    /// Crop and/or pad to `target_shape`.
    fn crop_or_pad(&self, target_shape: Vec<usize>) -> PyResult<Self> {
        Ok(Self {
            inner: transforms::crop_or_pad(&self.inner, &target_shape)
                .map_err(|e| PyValueError::new_err(e.to_string()))?,
        })
    }

    /// Flip along the given axes.
    fn flip(&self, axes: Vec<usize>) -> PyResult<Self> {
        Ok(Self {
            inner: transforms::flip(&self.inner, &axes)
                .map_err(|e| PyValueError::new_err(e.to_string()))?,
        })
    }

    /// Whether the underlying image reports itself as materialized.
    fn is_materialized(&self) -> bool {
        self.inner.is_materialized()
    }

    /// Return a materialized copy of the image.
    fn materialize(&self) -> PyResult<Self> {
        Ok(Self {
            inner: self
                .inner
                .materialize()
                .map_err(|e| to_py_err(e, "materialize"))?,
        })
    }

    fn __repr__(&self) -> String {
        format!(
            "NiftiImage(shape={:?}, dtype={}, spacing={:?}, orientation={})",
            self.shape(),
            self.dtype(),
            self.spacing(),
            self.orientation()
        )
    }
}
/// Module-level z-score normalization; returns a new image.
#[pyfunction]
fn z_normalization(image: &PyNiftiImage) -> PyResult<PyNiftiImage> {
    transforms::z_normalization(&image.inner)
        .map(|inner| PyNiftiImage { inner })
        .map_err(|e| to_py_err(e, "z_normalization"))
}
/// Linearly rescale intensities into `output_range` (default `(0.0, 1.0)`).
#[pyfunction]
#[pyo3(signature = (image, output_range=(0.0, 1.0)))]
fn rescale_intensity(image: &PyNiftiImage, output_range: (f64, f64)) -> PyResult<PyNiftiImage> {
    transforms::rescale_intensity(&image.inner, output_range.0, output_range.1)
        .map(|inner| PyNiftiImage { inner })
        .map_err(|e| to_py_err(e, "rescale_intensity"))
}
/// Clamp intensities to `[min_value, max_value]`.
///
/// Both bounds must be finite with `min_value <= max_value` (`ValueError` otherwise).
#[pyfunction]
fn clamp(image: &PyNiftiImage, min_value: f64, max_value: f64) -> PyResult<PyNiftiImage> {
    validate_intensity_range(min_value, max_value, "clamp")?;
    let clamped = transforms::clamp(&image.inner, min_value, max_value)
        .map_err(|e| to_py_err(e, "clamp"))?;
    Ok(PyNiftiImage { inner: clamped })
}
/// Crop and/or pad an image to a 3-element `target_shape`.
#[pyfunction]
fn crop_or_pad(image: &PyNiftiImage, target_shape: Vec<usize>) -> PyResult<PyNiftiImage> {
    match target_shape.len() {
        3 => transforms::crop_or_pad(&image.inner, &target_shape)
            .map(|inner| PyNiftiImage { inner })
            .map_err(|e| PyValueError::new_err(e.to_string())),
        _ => Err(PyValueError::new_err(
            "target_shape must be a 3-element sequence",
        )),
    }
}
#[pyfunction]
#[pyo3(signature = (image, target_spacing, method=None))]
fn resample(
image: &PyNiftiImage,
target_spacing: (f32, f32, f32),
method: Option<&str>,
) -> PyResult<PyNiftiImage> {
let interp = match method.unwrap_or("trilinear") {
"trilinear" | "linear" => Interpolation::Trilinear,
"nearest" => Interpolation::Nearest,
other => {
return Err(PyValueError::new_err(format!(
"method must be 'trilinear' or 'nearest', got {}",
other
)))
}
};
let spacing = [target_spacing.0, target_spacing.1, target_spacing.2];
let resampled = transforms::resample_to_spacing(&image.inner, spacing, interp)
.map_err(|e| PyValueError::new_err(format!("Resampling failed: {}", e)))?;
Ok(PyNiftiImage { inner: resampled })
}
/// Reorient an image to the orientation code `orientation`.
#[pyfunction]
fn reorient(image: &PyNiftiImage, orientation: &str) -> PyResult<PyNiftiImage> {
    let target = orientation
        .parse::<Orientation>()
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    let reoriented = transforms::reorient(&image.inner, target)
        .map_err(|e| PyValueError::new_err(format!("Reorientation failed: {}", e)))?;
    Ok(PyNiftiImage { inner: reoriented })
}
/// Load a whole NIfTI file.
///
/// Raises `FileNotFoundError` for a missing file and `ValueError` for an empty,
/// NUL-containing or non-UTF-8 path; load failures map through `to_py_err`.
#[pyfunction]
fn load(path: &str) -> PyResult<PyNiftiImage> {
    let checked = validate_file_path(path, "load")?;
    match checked.to_str() {
        Some(utf8_path) => nifti::load(utf8_path)
            .map(|inner| PyNiftiImage { inner })
            .map_err(|e| to_py_err(e, &format!("Failed to load {}", path))),
        None => Err(PyValueError::new_err("path contains invalid UTF-8")),
    }
}
#[pyfunction]
#[pyo3(signature = (path, dtype=None, device="cpu"))]
fn load_to_torch(
py: Python<'_>,
path: &str,
dtype: Option<PyObject>,
device: &str,
) -> PyResult<PyObject> {
let img = nifti::load(path).map_err(|e| to_py_err(e, &format!("Failed to load {}", path)))?;
let py_img = PyNiftiImage { inner: img };
py_img.to_torch_with_dtype_and_device(py, dtype, Some(device))
}
#[pyfunction]
fn load_cropped(
path: &str,
crop_offset: [usize; 3],
crop_shape: [usize; 3],
) -> PyResult<PyNiftiImage> {
nifti::load_cropped(path, crop_offset, crop_shape)
.map(|inner| PyNiftiImage { inner })
.map_err(|e| to_py_err(e, &format!("Failed to load cropped {}", path)))
}
#[pyfunction]
#[pyo3(signature = (path, output_shape, target_spacing=None, target_orientation=None, output_offset=None))]
fn load_resampled(
path: &str,
output_shape: [usize; 3],
target_spacing: Option<[f32; 3]>,
target_orientation: Option<String>,
output_offset: Option<[usize; 3]>,
) -> PyResult<PyNiftiImage> {
use crate::transforms::Orientation;
let orientation = match target_orientation {
Some(s) => Some(
s.parse::<Orientation>()
.map_err(|e| PyValueError::new_err(format!("{}", e)))?,
),
None => None,
};
let config = nifti::LoadCroppedConfig {
output_shape,
target_spacing,
target_orientation: orientation,
output_offset,
};
nifti::load_cropped_config(path, config)
.map(|inner| PyNiftiImage { inner })
.map_err(|e| to_py_err(e, &format!("Failed to load cropped {}", path)))
}
#[pyfunction]
#[pyo3(signature = (path, output_shape, target_spacing=None, target_orientation=None, output_offset=None, dtype=None, device="cpu"))]
#[allow(clippy::too_many_arguments)]
fn load_cropped_to_torch(
py: Python<'_>,
path: &str,
output_shape: [usize; 3],
target_spacing: Option<[f32; 3]>,
target_orientation: Option<String>,
output_offset: Option<[usize; 3]>,
dtype: Option<PyObject>,
device: &str,
) -> PyResult<PyObject> {
use crate::transforms::Orientation;
validate_shape(&output_shape, "output_shape")?;
if let Some(ref spacing) = target_spacing {
validate_spacing(spacing, "target_spacing")?;
}
let orientation = match target_orientation {
Some(s) => Some(
s.parse::<Orientation>()
.map_err(|e| PyValueError::new_err(format!("{}", e)))?,
),
None => None,
};
let config = nifti::LoadCroppedConfig {
output_shape,
target_spacing,
target_orientation: orientation,
output_offset,
};
let img = nifti::load_cropped_config(path, config)
.map_err(|e| to_py_err(e, &format!("Failed to load cropped {}", path)))?;
let py_img = PyNiftiImage { inner: img };
py_img.to_torch_with_dtype_and_device(py, dtype, Some(device))
}
#[pyfunction]
#[pyo3(signature = (path, output_shape, target_spacing=None, target_orientation=None, output_offset=None, dtype=None, device="cpu"))]
#[allow(clippy::too_many_arguments)]
fn load_cropped_to_jax(
py: Python<'_>,
path: &str,
output_shape: [usize; 3],
target_spacing: Option<[f32; 3]>,
target_orientation: Option<String>,
output_offset: Option<[usize; 3]>,
dtype: Option<PyObject>,
device: &str,
) -> PyResult<PyObject> {
use crate::transforms::Orientation;
validate_shape(&output_shape, "output_shape")?;
if let Some(ref spacing) = target_spacing {
validate_spacing(spacing, "target_spacing")?;
}
let orientation = match target_orientation {
Some(s) => Some(
s.parse::<Orientation>()
.map_err(|e| PyValueError::new_err(format!("{}", e)))?,
),
None => None,
};
let config = nifti::LoadCroppedConfig {
output_shape,
target_spacing,
target_orientation: orientation,
output_offset,
};
let img = nifti::load_cropped_config(path, config)
.map_err(|e| to_py_err(e, &format!("Failed to load cropped {}", path)))?;
let py_img = PyNiftiImage { inner: img };
py_img.to_jax_with_dtype_and_device(py, dtype, Some(device))
}
#[pyclass(name = "TrainingDataLoader")]
pub struct PyTrainingDataLoader {
loader: nifti::TrainingDataLoader,
}
#[pymethods]
impl PyTrainingDataLoader {
#[new]
#[pyo3(signature = (volumes, patch_size, patches_per_volume, patch_overlap, randomize, cache_size=None))]
fn new(
volumes: Vec<String>,
patch_size: [usize; 3],
patches_per_volume: usize,
patch_overlap: [usize; 3],
randomize: bool,
cache_size: Option<usize>,
) -> PyResult<Self> {
for i in 0..3 {
if patch_overlap[i] >= patch_size[i] {
return Err(PyValueError::new_err(
"patch_overlap must be smaller than patch_size in all dimensions",
));
}
}
let config = nifti::CropLoaderConfig {
patch_size,
patches_per_volume,
patch_overlap,
randomize,
};
let cache_size = cache_size.unwrap_or(1000);
let loader = nifti::TrainingDataLoader::new(volumes, config, cache_size)
.map_err(|e| PyValueError::new_err(format!("Failed to create loader: {}", e)))?;
Ok(Self { loader })
}
fn next_patch(&mut self) -> PyResult<PyNiftiImage> {
match self.loader.next_patch() {
Ok(inner) => Ok(PyNiftiImage { inner }),
Err(MedrsError::Exhausted(msg)) => Err(PyStopIteration::new_err(msg)),
Err(e) => Err(to_py_err(e, "next_patch")),
}
}
fn __len__(&self) -> usize {
self.loader.volumes_len() * self.loader.patches_per_volume()
}
fn __iter__(mut slf: PyRefMut<'_, Self>) -> PyResult<PyRefMut<'_, Self>> {
slf.loader
.reset()
.map_err(|e| PyValueError::new_err(format!("Failed to reset loader: {}", e)))?;
Ok(slf)
}
fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult<Option<PyNiftiImage>> {
match slf.loader.next_patch() {
Ok(img) => Ok(Some(PyNiftiImage { inner: img })),
Err(MedrsError::Exhausted(_)) => Ok(None),
Err(e) => Err(to_py_err(e, "iterator")),
}
}
fn stats(&self) -> String {
format!("{}", self.loader.stats())
}
fn reset(&mut self) -> PyResult<()> {
self.loader
.reset()
.map_err(|e| PyValueError::new_err(format!("Failed to reset loader: {}", e)))
}
}
/// Python `TransformPipeline`: a chainable wrapper around the Rust transform pipeline.
///
/// Each builder method replaces the wrapped pipeline with the result of the matching
/// Rust builder call and returns `self`, so calls can be chained from Python.
#[pyclass(name = "TransformPipeline")]
pub struct PyTransformPipeline {
    inner: RustTransformPipeline,
}

#[pymethods]
impl PyTransformPipeline {
    /// Create a pipeline; `lazy=False` turns lazy mode off.
    #[new]
    #[pyo3(signature = (lazy=true))]
    fn new(lazy: bool) -> Self {
        let pipeline = RustTransformPipeline::new();
        Self {
            inner: if lazy { pipeline } else { pipeline.lazy(false) },
        }
    }

    /// Chain a z-normalization step.
    fn z_normalize(mut slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut slf.inner);
        slf.inner = pipeline.z_normalize();
        slf
    }

    /// Chain an intensity rescale into `[out_min, out_max]`.
    fn rescale(mut slf: PyRefMut<'_, Self>, out_min: f32, out_max: f32) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut slf.inner);
        slf.inner = pipeline.rescale(out_min, out_max);
        slf
    }

    /// Chain an intensity clamp to `[min, max]`.
    fn clamp(mut slf: PyRefMut<'_, Self>, min: f32, max: f32) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut slf.inner);
        slf.inner = pipeline.clamp(min, max);
        slf
    }

    /// Chain a resample to the given voxel spacing.
    fn resample_to_spacing(mut slf: PyRefMut<'_, Self>, spacing: [f32; 3]) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut slf.inner);
        slf.inner = pipeline.resample_to_spacing(spacing);
        slf
    }

    /// Chain a resample to the given grid shape.
    fn resample_to_shape(mut slf: PyRefMut<'_, Self>, shape: [usize; 3]) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut slf.inner);
        slf.inner = pipeline.resample_to_shape(shape);
        slf
    }

    /// Chain a flip along `axes`.
    fn flip(mut slf: PyRefMut<'_, Self>, axes: Vec<usize>) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut slf.inner);
        slf.inner = pipeline.flip(&axes);
        slf
    }

    /// Switch lazy mode on or off.
    fn set_lazy(mut slf: PyRefMut<'_, Self>, lazy: bool) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut slf.inner);
        slf.inner = pipeline.lazy(lazy);
        slf
    }

    /// Run the pipeline on `image` and return the result as a new image.
    fn apply(&self, image: &PyNiftiImage) -> PyResult<PyNiftiImage> {
        self.inner
            .apply(&image.inner)
            .map(|inner| PyNiftiImage { inner })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // `&self` is required by the Python `__repr__` protocol even though no state is shown.
    #[allow(clippy::unused_self)]
    fn __repr__(&self) -> String {
        String::from("TransformPipeline(...)")
    }
}
/// Randomly flip an image along a subset of `axes` (each must be 0, 1 or 2).
///
/// `prob`, if given, must be a finite probability in `[0, 1]`.
#[pyfunction]
#[pyo3(signature = (image, axes, prob=None, seed=None))]
fn random_flip(
    image: &PyNiftiImage,
    axes: Vec<usize>,
    prob: Option<f32>,
    seed: Option<u64>,
) -> PyResult<PyNiftiImage> {
    if let Some(p) = prob {
        validate_probability(f64::from(p), "random_flip")?;
    }
    if let Some(&axis) = axes.iter().find(|&&candidate| candidate >= 3) {
        return Err(PyValueError::new_err(format!(
            "random_flip: axis {} is out of range (must be 0, 1, or 2)",
            axis
        )));
    }
    transforms::random_flip(&image.inner, &axes, prob, seed)
        .map(|inner| PyNiftiImage { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
}
/// Add random Gaussian noise; `std` and `seed` are forwarded to the Rust transform.
#[pyfunction]
#[pyo3(signature = (image, std=None, seed=None))]
fn random_gaussian_noise(
    image: &PyNiftiImage,
    std: Option<f32>,
    seed: Option<u64>,
) -> PyResult<PyNiftiImage> {
    transforms::random_gaussian_noise(&image.inner, std, seed)
        .map(|inner| PyNiftiImage { inner })
        .map_err(|e| to_py_err(e, "random_gaussian_noise"))
}
/// Randomly scale intensities; `scale_range` and `seed` are forwarded to the Rust transform.
#[pyfunction]
#[pyo3(signature = (image, scale_range=None, seed=None))]
fn random_intensity_scale(
    image: &PyNiftiImage,
    scale_range: Option<f32>,
    seed: Option<u64>,
) -> PyResult<PyNiftiImage> {
    transforms::random_intensity_scale(&image.inner, scale_range, seed)
        .map(|inner| PyNiftiImage { inner })
        .map_err(|e| to_py_err(e, "random_intensity_scale"))
}
/// Randomly shift intensities; `shift_range` and `seed` are forwarded to the Rust transform.
#[pyfunction]
#[pyo3(signature = (image, shift_range=None, seed=None))]
fn random_intensity_shift(
    image: &PyNiftiImage,
    shift_range: Option<f32>,
    seed: Option<u64>,
) -> PyResult<PyNiftiImage> {
    transforms::random_intensity_shift(&image.inner, shift_range, seed)
        .map(|inner| PyNiftiImage { inner })
        .map_err(|e| to_py_err(e, "random_intensity_shift"))
}
/// Randomly rotate by a multiple of 90 degrees in the plane given by `axes`.
#[pyfunction]
#[pyo3(signature = (image, axes, seed=None))]
fn random_rotate_90(
    image: &PyNiftiImage,
    axes: (usize, usize),
    seed: Option<u64>,
) -> PyResult<PyNiftiImage> {
    transforms::random_rotate_90(&image.inner, axes, seed)
        .map(|inner| PyNiftiImage { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
}
/// Random gamma adjustment; `gamma_range` and `seed` are forwarded to the Rust transform.
#[pyfunction]
#[pyo3(signature = (image, gamma_range=None, seed=None))]
fn random_gamma(
    image: &PyNiftiImage,
    gamma_range: Option<(f32, f32)>,
    seed: Option<u64>,
) -> PyResult<PyNiftiImage> {
    transforms::random_gamma(&image.inner, gamma_range, seed)
        .map(|inner| PyNiftiImage { inner })
        .map_err(|e| to_py_err(e, "random_gamma"))
}
/// Apply the library's combined random augmentation (`transforms::random_augment`).
#[pyfunction]
#[pyo3(signature = (image, seed=None))]
fn random_augment(image: &PyNiftiImage, seed: Option<u64>) -> PyResult<PyNiftiImage> {
    transforms::random_augment(&image.inner, seed)
        .map(|inner| PyNiftiImage { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
}
/// Python extension module `_medrs`.
///
/// Registers the `NiftiImage`, `TrainingDataLoader` and `TransformPipeline` classes
/// together with the module-level transform, loading and augmentation functions.
#[pymodule]
fn _medrs(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Classes.
    m.add_class::<PyNiftiImage>()?;
    m.add_class::<PyTrainingDataLoader>()?;
    m.add_class::<PyTransformPipeline>()?;
    // Deterministic intensity / spatial transforms.
    m.add_function(wrap_pyfunction!(z_normalization, m)?)?;
    m.add_function(wrap_pyfunction!(rescale_intensity, m)?)?;
    m.add_function(wrap_pyfunction!(clamp, m)?)?;
    m.add_function(wrap_pyfunction!(crop_or_pad, m)?)?;
    m.add_function(wrap_pyfunction!(resample, m)?)?;
    m.add_function(wrap_pyfunction!(reorient, m)?)?;
    // File loading (whole volume, cropped/resampled regions, framework tensors).
    m.add_function(wrap_pyfunction!(load, m)?)?;
    m.add_function(wrap_pyfunction!(load_to_torch, m)?)?;
    m.add_function(wrap_pyfunction!(load_cropped, m)?)?;
    m.add_function(wrap_pyfunction!(load_resampled, m)?)?;
    m.add_function(wrap_pyfunction!(load_cropped_to_torch, m)?)?;
    m.add_function(wrap_pyfunction!(load_cropped_to_jax, m)?)?;
    m.add_function(wrap_pyfunction!(load_label_aware_cropped, m)?)?;
    // Crop-region planning helpers.
    m.add_function(wrap_pyfunction!(compute_crop_regions, m)?)?;
    m.add_function(wrap_pyfunction!(compute_random_spatial_crops, m)?)?;
    m.add_function(wrap_pyfunction!(compute_center_crop, m)?)?;
    // Random augmentations (all accept an optional `seed`).
    m.add_function(wrap_pyfunction!(random_flip, m)?)?;
    m.add_function(wrap_pyfunction!(random_gaussian_noise, m)?)?;
    m.add_function(wrap_pyfunction!(random_intensity_scale, m)?)?;
    m.add_function(wrap_pyfunction!(random_intensity_shift, m)?)?;
    m.add_function(wrap_pyfunction!(random_rotate_90, m)?)?;
    m.add_function(wrap_pyfunction!(random_gamma, m)?)?;
    m.add_function(wrap_pyfunction!(random_augment, m)?)?;
    Ok(())
}
/// Copy typed image data into a numpy array with shape `shape`.
///
/// Each variant keeps its native element type, except `F16`/`BF16`, which are
/// widened to `f32` with `to_f32` first. Errors come from the numpy `reshape`.
fn arraydata_to_numpy(py: Python<'_>, data: &ArrayData, shape: &[usize]) -> PyResult<PyObject> {
    let target_dim = IxDyn(shape);
    // Hand an owned ndarray to numpy, reshape it to `target_dim`, and erase the
    // element type. Only one match arm runs, so moving `target_dim` in each is fine.
    macro_rules! owned_into_py {
        ($owned:expr) => {
            $owned
                .into_pyarray(py)
                .reshape(target_dim)?
                .unbind()
                .into_any()
        };
    }
    let object = match data {
        ArrayData::U8(values) => owned_into_py!(values.to_owned()),
        ArrayData::I8(values) => owned_into_py!(values.to_owned()),
        ArrayData::I16(values) => owned_into_py!(values.to_owned()),
        ArrayData::U16(values) => owned_into_py!(values.to_owned()),
        ArrayData::I32(values) => owned_into_py!(values.to_owned()),
        ArrayData::U32(values) => owned_into_py!(values.to_owned()),
        ArrayData::I64(values) => owned_into_py!(values.to_owned()),
        ArrayData::U64(values) => owned_into_py!(values.to_owned()),
        ArrayData::F16(values) => owned_into_py!(values.mapv(|v| v.to_f32())),
        ArrayData::BF16(values) => owned_into_py!(values.mapv(|v| v.to_f32())),
        ArrayData::F32(values) => owned_into_py!(values.to_owned()),
        ArrayData::F64(values) => owned_into_py!(values.to_owned()),
    };
    Ok(object)
}
/// Fast path for f32 images: build a numpy array from the image's f32 view.
///
/// Returns `None` when `as_view_f32` has no view to offer. The returned array is a
/// copy (`PyArray::from_array` copies), so it does not alias the image's buffer.
fn to_numpy_view<'py>(
    py: Python<'py>,
    image: &RustNiftiImage,
) -> Option<Bound<'py, PyArrayDyn<f32>>> {
    image
        .as_view_f32()
        .map(|view| PyArrayDyn::from_array(py, &view))
}
fn to_numpy_array<'py>(
py: Python<'py>,
image: &RustNiftiImage,
) -> PyResult<Bound<'py, PyArrayDyn<f32>>> {
let data = image.to_f32().map_err(|e| to_py_err(e, "to_f32"))?;
let shape: Vec<usize> = data.shape().to_vec();
let slice = data
.as_slice_memory_order()
.ok_or_else(|| PyValueError::new_err("Array is not contiguous in memory"))?;
let np = py.import("numpy")?;
let flat = slice.to_vec().into_pyarray(py);
let kwargs = [("order", "F")].into_py_dict(py)?;
let reshaped = np
.call_method("reshape", (flat, &shape), Some(&kwargs))
.map_err(|e| PyValueError::new_err(format!("Failed to reshape array: {}", e)))?;
reshaped
.extract::<Bound<'py, PyArrayDyn<f32>>>()
.map_err(|e| PyValueError::new_err(format!("Failed to extract array: {}", e)))
}
fn to_numpy_view_native(py: Python<'_>, image: &RustNiftiImage) -> Option<PyObject> {
let arr_obj = match image.dtype() {
DataType::UInt8 => image
.as_view_t::<u8>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::Int8 => image
.as_view_t::<i8>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::Int16 => image
.as_view_t::<i16>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::UInt16 => image
.as_view_t::<u16>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::Int32 => image
.as_view_t::<i32>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::UInt32 => image
.as_view_t::<u32>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::Int64 => image
.as_view_t::<i64>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::UInt64 => image
.as_view_t::<u64>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::Float16 | DataType::BFloat16 => None,
DataType::Float32 => image
.as_view_t::<f32>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::Float64 => image
.as_view_t::<f64>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
};
arr_obj
}
/// Looks up the `torch.<name>` dtype object corresponding to a NIfTI `DataType`.
///
/// Returns `None` in three cases:
/// - `torch` cannot be imported;
/// - the dtype is `UInt16`, `UInt32` or `UInt64`, which this mapping does not translate;
/// - the `torch` module has no attribute of the mapped name.
fn torch_dtype(py: Python<'_>, dtype: DataType) -> Option<PyObject> {
    // The import is attempted before the dtype is inspected, for every dtype.
    let torch_module = py.import("torch").ok()?;
    let attr_name: Option<&str> = match dtype {
        DataType::UInt8 => Some("uint8"),
        DataType::Int8 => Some("int8"),
        DataType::Int16 => Some("int16"),
        DataType::Int32 => Some("int32"),
        DataType::Int64 => Some("int64"),
        DataType::Float16 => Some("float16"),
        DataType::BFloat16 => Some("bfloat16"),
        DataType::Float32 => Some("float32"),
        DataType::Float64 => Some("float64"),
        DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => None,
    };
    torch_module
        .getattr(attr_name?)
        .ok()
        .map(|dtype_obj| dtype_obj.unbind())
}
#[pyfunction]
#[pyo3(signature = (image_path, label_path, patch_size, pos_neg_ratio=None, min_pos_samples=None, seed=None))]
fn load_label_aware_cropped(
_py: Python<'_>,
image_path: &str,
label_path: &str,
patch_size: Vec<usize>,
pos_neg_ratio: Option<f64>,
min_pos_samples: Option<usize>,
seed: Option<u64>,
) -> PyResult<(PyNiftiImage, PyNiftiImage)> {
let patch_size = parse_shape3(&patch_size, "patch_size")?;
let image = nifti::load(image_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load image: {}", e))
})?;
let label = nifti::load(label_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load label: {}", e))
})?;
let config = transforms::RandCropByPosNegLabelConfig {
patch_size,
pos_neg_ratio: pos_neg_ratio.unwrap_or(1.0) as f32,
min_pos_samples: min_pos_samples.unwrap_or(4),
seed,
background_label: 0.0,
};
let crop_regions = compute_label_aware_crop_regions(&config, &image, &label, 1)
.map_err(|e| to_py_err(e, "compute_label_aware_crop_regions"))?;
if crop_regions.is_empty() {
return Err(PyValueError::new_err("No valid crop regions found"));
}
let region = &crop_regions[0];
let cropped_image =
nifti::load_cropped(image_path, region.start, region.size).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
"Failed to load cropped image: {}",
e
))
})?;
let cropped_label =
nifti::load_cropped(label_path, region.start, region.size).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
"Failed to load cropped label: {}",
e
))
})?;
Ok((
PyNiftiImage {
inner: cropped_image,
},
PyNiftiImage {
inner: cropped_label,
},
))
}
#[pyfunction]
#[pyo3(signature = (image_path, label_path, patch_size, num_samples, pos_neg_ratio=None, min_pos_samples=None, seed=None))]
#[allow(clippy::too_many_arguments)]
fn compute_crop_regions(
py: Python<'_>,
image_path: &str,
label_path: &str,
patch_size: Vec<usize>,
num_samples: usize,
pos_neg_ratio: Option<f64>,
min_pos_samples: Option<usize>,
seed: Option<u64>,
) -> PyResult<Vec<PyObject>> {
let patch_size = parse_shape3(&patch_size, "patch_size")?;
let image = nifti::load(image_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load image: {}", e))
})?;
let label = nifti::load(label_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load label: {}", e))
})?;
let config = transforms::RandCropByPosNegLabelConfig {
patch_size,
pos_neg_ratio: pos_neg_ratio.unwrap_or(1.0) as f32,
min_pos_samples: min_pos_samples.unwrap_or(4),
seed,
background_label: 0.0,
};
let crop_regions = compute_label_aware_crop_regions(&config, &image, &label, num_samples)
.map_err(|e| to_py_err(e, "compute_label_aware_crop_regions"))?;
let mut regions_py = Vec::new();
for region in crop_regions {
let region_dict = pyo3::types::PyDict::new(py);
region_dict.set_item("start", region.start.to_vec())?;
region_dict.set_item("end", region.end.to_vec())?;
region_dict.set_item("size", region.size.to_vec())?;
regions_py.push(region_dict.into());
}
Ok(regions_py)
}
#[pyfunction]
#[pyo3(signature = (image_path, patch_size, num_samples, seed=None, allow_smaller=None))]
fn compute_random_spatial_crops(
py: Python<'_>,
image_path: &str,
patch_size: Vec<usize>,
num_samples: usize,
seed: Option<u64>,
allow_smaller: Option<bool>,
) -> PyResult<Vec<PyObject>> {
let patch_size = parse_shape3(&patch_size, "patch_size")?;
let image = nifti::load(image_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load image: {}", e))
})?;
let config = transforms::SpatialCropConfig {
patch_size,
seed,
allow_smaller: allow_smaller.unwrap_or(false),
};
let crop_regions = compute_random_spatial_crop_regions(&config, &image, num_samples);
let mut regions_py = Vec::new();
for region in crop_regions {
let region_dict = pyo3::types::PyDict::new(py);
region_dict.set_item("start", region.start.to_vec())?;
region_dict.set_item("end", region.end.to_vec())?;
region_dict.set_item("size", region.size.to_vec())?;
regions_py.push(region_dict.into());
}
Ok(regions_py)
}
#[pyfunction]
fn compute_center_crop(
py: Python<'_>,
image_path: &str,
patch_size: Vec<usize>,
) -> PyResult<PyObject> {
let patch_size = parse_shape3(&patch_size, "patch_size")?;
let image = nifti::load(image_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load image: {}", e))
})?;
let region = compute_center_crop_regions(patch_size, &image);
let region_dict = pyo3::types::PyDict::new(py);
region_dict.set_item("start", region.start.to_vec())?;
region_dict.set_item("end", region.end.to_vec())?;
region_dict.set_item("size", region.size.to_vec())?;
Ok(region_dict.into())
}