#![allow(clippy::doc_markdown)]
use ndarray::{ArrayD, IxDyn};
use numpy::{IntoPyArray, PyArrayDyn, PyArrayMethods, PyReadonlyArrayDyn};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use crate::nifti::image::ArrayData;
use crate::nifti::DataType;
use crate::nifti::{self, NiftiImage as RustNiftiImage};
use crate::pipeline::TransformPipeline as RustTransformPipeline;
use crate::transforms::crop::{
compute_center_crop_regions, compute_label_aware_crop_regions,
compute_random_spatial_crop_regions,
};
use crate::transforms::{self, Interpolation, Orientation};
/// Python wrapper around the crate's NIfTI image type, exposed to Python as `NiftiImage`.
#[pyclass(name = "NiftiImage")]
struct PyNiftiImage {
    /// The wrapped Rust image that the Python methods operate on.
    inner: RustNiftiImage,
}
#[pymethods]
impl PyNiftiImage {
    /// Creates an image from a float32 NumPy array with at least three dimensions.
    ///
    /// The elements are copied into a C-ordered (row-major) Rust array. `affine` defaults to
    /// the 4x4 identity matrix.
    ///
    /// Raises `ValueError` if the array has fewer than three dimensions.
    #[new]
    #[pyo3(signature = (data, affine=None))]
    fn new<'py>(
        data: PyReadonlyArrayDyn<'py, f32>,
        affine: Option<[[f32; 4]; 4]>,
    ) -> PyResult<Self> {
        let arr = data.as_array();
        let shape = arr.shape();
        if shape.len() < 3 {
            return Err(PyValueError::new_err(
                "Array must have at least 3 dimensions (D,H,W)",
            ));
        }
        // Logical iteration order is row-major, matching `from_shape_vec`'s default C layout.
        let data_vec: Vec<f32> = arr.iter().copied().collect();
        let array = ArrayD::from_shape_vec(shape.to_vec(), data_vec)
            .map_err(|e| PyValueError::new_err(format!("Invalid array shape: {}", e)))?;
        Ok(Self {
            inner: RustNiftiImage::from_array(array, affine.unwrap_or(IDENTITY_AFFINE)),
        })
    }

    /// Shape of the image data.
    #[getter]
    fn shape(&self) -> Vec<usize> {
        self.inner.shape().to_vec()
    }

    /// Number of dimensions of the image data.
    #[getter]
    fn ndim(&self) -> usize {
        self.inner.ndim()
    }

    /// Name of the stored element type (from `DataType::type_name`).
    #[getter]
    fn dtype(&self) -> &'static str {
        self.inner.dtype().type_name()
    }

    /// Voxel spacing reported by the wrapped image.
    #[getter]
    fn spacing(&self) -> Vec<f32> {
        self.inner.spacing().to_vec()
    }

    /// The 4x4 affine matrix of the image.
    #[getter]
    fn affine(&self) -> [[f32; 4]; 4] {
        self.inner.affine()
    }

    /// Replaces the 4x4 affine matrix of the image.
    #[setter]
    fn set_affine(&mut self, affine: [[f32; 4]; 4]) {
        self.inner.set_affine(affine);
    }

    /// Orientation code derived from the current affine.
    #[getter]
    fn orientation(&self) -> String {
        let affine = self.inner.affine();
        transforms::orientation_from_affine(&affine).to_string()
    }

    /// Image data as a float32 NumPy array (same as `to_numpy()`).
    #[getter]
    fn data<'py>(&self, py: Python<'py>) -> Bound<'py, PyArrayDyn<f32>> {
        self.to_numpy(py)
    }

    /// Returns the image data as a float32 NumPy array.
    ///
    /// Data available as a float32 view is copied directly; anything else is converted to
    /// float32 first. Either way the result is a fresh NumPy-owned copy.
    fn to_numpy<'py>(&self, py: Python<'py>) -> Bound<'py, PyArrayDyn<f32>> {
        to_numpy_view(py, &self.inner).unwrap_or_else(|| to_numpy_array(py, &self.inner))
    }

    /// Same as `to_numpy()`.
    ///
    /// Despite the name, the result does not alias the image's memory:
    /// `PyArrayDyn::from_array` copies the elements into a NumPy-owned buffer.
    fn to_numpy_view<'py>(&self, py: Python<'py>) -> Bound<'py, PyArrayDyn<f32>> {
        self.to_numpy(py)
    }

    /// Converts the image to a CPU `torch.Tensor`.
    ///
    /// Element types that `torch_dtype` maps to a torch dtype keep that dtype. Float16 and
    /// bfloat16 go through a float32 NumPy array and are cast back. Types it does not map
    /// (unsigned 16/32/64-bit) are exported as float32.
    fn to_torch<'py>(&self, py: Python<'py>) -> PyResult<PyObject> {
        let torch = py.import("torch")?;
        let target_dtype = torch_dtype(py, self.inner.dtype());
        let np_obj: PyObject = if target_dtype.is_some() {
            match to_numpy_view_native(py, &self.inner) {
                Some(view) => view,
                // No native view (e.g. float16/bfloat16): `arraydata_to_numpy` widens to f32.
                None => arraydata_to_numpy(py, &self.inner.owned_data(), self.inner.shape())?,
            }
        } else {
            to_numpy_array(py, &self.inner).into_any().unbind()
        };
        let tensor = torch.getattr("from_numpy")?.call1((np_obj,))?;
        let tensor = match target_dtype {
            Some(dt) => tensor.call_method1("to", (dt,))?,
            None => tensor,
        };
        Ok(tensor.unbind())
    }

    /// Converts the image to a JAX array via `jax.numpy.array` on float32 data.
    fn to_jax<'py>(&self, py: Python<'py>) -> PyResult<PyObject> {
        let jnp = py.import("jax.numpy")?;
        let np_obj = to_numpy_array(py, &self.inner);
        let arr = jnp.getattr("array")?.call1((np_obj,))?;
        Ok(arr.unbind())
    }

    /// Converts the image to a `torch.Tensor` on `device` (default `"cpu"`), optionally cast
    /// to `dtype`.
    ///
    /// Element types without a torch equivalent in `torch_dtype` (unsigned 16/32/64-bit) are
    /// exported as float32, matching `to_torch`. Older torch releases may reject those NumPy
    /// dtypes in `torch.from_numpy`.
    #[pyo3(signature = (dtype=None, device=None))]
    fn to_torch_with_dtype_and_device<'py>(
        &self,
        py: Python<'py>,
        dtype: Option<PyObject>,
        device: Option<&str>,
    ) -> PyResult<PyObject> {
        let device_str = device.unwrap_or("cpu");
        let torch = py.import("torch")?;
        let native = if torch_dtype(py, self.inner.dtype()).is_some() {
            to_numpy_view_native(py, &self.inner)
        } else {
            None
        };
        let np_obj: PyObject = match native {
            Some(view) => view,
            None => to_numpy_array(py, &self.inner).into_any().unbind(),
        };
        let tensor = torch.getattr("from_numpy")?.call1((np_obj,))?;
        let tensor = tensor.call_method1("to", (device_str,))?;
        let tensor = match dtype {
            Some(dt) => tensor.call_method1("to", (dt,))?,
            None => tensor,
        };
        Ok(tensor.unbind())
    }

    /// Converts the image to a JAX array placed on `device`, optionally cast to `dtype`.
    ///
    /// `device` is `"cpu"` (default), `"cuda"` (first CUDA device) or `"cuda:N"`. Any other
    /// value, including a non-numeric `N`, raises `ValueError`.
    #[pyo3(signature = (dtype=None, device=None))]
    fn to_jax_with_dtype_and_device<'py>(
        &self,
        py: Python<'py>,
        dtype: Option<PyObject>,
        device: Option<&str>,
    ) -> PyResult<PyObject> {
        let device_str = device.unwrap_or("cpu");
        let jax = py.import("jax")?;
        let jnp = py.import("jax.numpy")?;
        // `jax.devices` takes a backend *name* and returns a list; pick one device from it.
        let devices = jax.getattr("devices")?;
        let device_obj = match device_str {
            "cpu" => devices.call1(("cpu",))?.get_item(0)?,
            "cuda" => devices.call1(("cuda",))?.get_item(0)?,
            other => {
                let index = other
                    .strip_prefix("cuda:")
                    .and_then(|id| id.parse::<usize>().ok())
                    .ok_or_else(|| {
                        PyValueError::new_err(format!("Unsupported device: {}", other))
                    })?;
                devices.call1(("cuda",))?.get_item(index)?
            }
        };
        let np_obj: PyObject = match to_numpy_view_native(py, &self.inner) {
            Some(view) => view,
            None => to_numpy_array(py, &self.inner).into_any().unbind(),
        };
        let mut arr = jnp.getattr("array")?.call1((np_obj,))?;
        if let Some(dt) = dtype {
            // Cast the array itself. `jnp.astype` would need the array as its first argument.
            arr = arr.call_method1("astype", (dt,))?;
        }
        let placed = jax.getattr("device_put")?.call1((arr, &device_obj))?;
        Ok(placed.unbind())
    }

    /// Returns the data as a NumPy array of its native element type. Float16 and bfloat16
    /// are widened to float32 by `arraydata_to_numpy`.
    fn to_numpy_native<'py>(&self, py: Python<'py>) -> PyResult<PyObject> {
        arraydata_to_numpy(py, &self.inner.owned_data(), self.inner.shape())
    }

    /// Creates an image from a float32 NumPy array, storing it in column-major (Fortran) order.
    ///
    /// A Fortran-contiguous input is copied directly. Any other layout (C order, strided,
    /// permuted axes, negative strides) is copied element-wise into a new column-major array.
    /// `affine` defaults to the identity.
    ///
    /// Raises `ValueError` if the array has fewer than three dimensions.
    #[staticmethod]
    #[pyo3(signature = (data, affine=None))]
    fn from_numpy<'py>(
        _py: Python<'py>,
        data: PyReadonlyArrayDyn<'py, f32>,
        affine: Option<[[f32; 4]; 4]>,
    ) -> PyResult<Self> {
        use ndarray::ShapeBuilder;
        let arr = data.as_array();
        let shape = arr.shape();
        if shape.len() < 3 {
            return Err(PyValueError::new_err(
                "Array must have at least 3 dimensions (D,H,W)",
            ));
        }
        let fortran_shape = IxDyn(shape).f();
        // The reversed-axes view is in standard (C) layout exactly when `arr` is
        // Fortran-contiguous. In that case its slice lists `arr`'s elements in column-major
        // order. (`as_slice_memory_order` also accepts other contiguous axis permutations,
        // whose memory order is not column-major, so it cannot be used for this test.)
        let transposed = arr.t();
        let array = if let Some(column_major) = transposed.as_slice() {
            ArrayD::from_shape_vec(fortran_shape, column_major.to_vec())
                .map_err(|e| PyValueError::new_err(format!("Invalid array shape: {}", e)))?
        } else {
            let mut column_major = ArrayD::<f32>::zeros(fortran_shape);
            column_major.assign(&arr);
            column_major
        };
        Ok(Self {
            inner: RustNiftiImage::from_array(array, affine.unwrap_or(IDENTITY_AFFINE)),
        })
    }

    /// Writes the image to `path` with `nifti::save`. Errors are raised as `ValueError`.
    fn save(&self, path: &str) -> PyResult<()> {
        nifti::save(&self.inner, path).map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Resamples to the given voxel `spacing`.
    ///
    /// `method` is `"trilinear"` (default; alias `"linear"`) or `"nearest"`.
    #[pyo3(signature = (spacing, method=None))]
    fn resample(&self, spacing: [f32; 3], method: Option<&str>) -> PyResult<Self> {
        let interp = interpolation_from_method(method)?;
        Ok(Self {
            inner: transforms::resample_to_spacing(&self.inner, spacing, interp),
        })
    }

    /// Resamples to the given output `shape`.
    ///
    /// `method` is `"trilinear"` (default; alias `"linear"`) or `"nearest"`.
    #[pyo3(signature = (shape, method=None))]
    fn resample_to_shape(&self, shape: [usize; 3], method: Option<&str>) -> PyResult<Self> {
        let interp = interpolation_from_method(method)?;
        Ok(Self {
            inner: transforms::resample_to_shape(&self.inner, shape, interp),
        })
    }

    /// Reorients the image to the given orientation code.
    ///
    /// Raises `ValueError` if the code is not recognised by `Orientation::from_str`.
    fn reorient(&self, orientation: &str) -> PyResult<Self> {
        let target = Orientation::from_str(orientation).ok_or_else(|| {
            PyValueError::new_err(format!("Invalid orientation code: {}. Valid codes: RAS, LAS, LPI, RPI, ASL, ARS, ALI, ARI, IPL, SPR, IAR, SAR, IPR, SPL, IAL, SAL", orientation))
        })?;
        Ok(Self {
            inner: transforms::reorient(&self.inner, target),
        })
    }

    /// Returns a copy processed by `transforms::z_normalization`.
    fn z_normalize(&self) -> Self {
        Self {
            inner: transforms::z_normalization(&self.inner),
        }
    }

    /// Returns a copy processed by `transforms::rescale_intensity` into `[out_min, out_max]`.
    fn rescale(&self, out_min: f64, out_max: f64) -> Self {
        Self {
            inner: transforms::rescale_intensity(&self.inner, out_min, out_max),
        }
    }

    /// Returns a copy processed by `transforms::clamp` with the given bounds.
    fn clamp(&self, min: f64, max: f64) -> Self {
        Self {
            inner: transforms::clamp(&self.inner, min, max),
        }
    }

    /// Crops or pads to `target_shape`. Errors are raised as `ValueError`.
    fn crop_or_pad(&self, target_shape: Vec<usize>) -> PyResult<Self> {
        Ok(Self {
            inner: transforms::crop_or_pad(&self.inner, &target_shape)
                .map_err(|e| PyValueError::new_err(e.to_string()))?,
        })
    }

    /// Flips along `axes`. Errors are raised as `ValueError`.
    fn flip(&self, axes: Vec<usize>) -> PyResult<Self> {
        Ok(Self {
            inner: transforms::flip(&self.inner, &axes)
                .map_err(|e| PyValueError::new_err(e.to_string()))?,
        })
    }

    /// Whether the wrapped image reports itself as materialized.
    fn is_materialized(&self) -> bool {
        self.inner.is_materialized()
    }

    /// Returns the result of the wrapped image's `materialize()`.
    fn materialize(&self) -> Self {
        Self {
            inner: self.inner.materialize(),
        }
    }

    fn __repr__(&self) -> String {
        format!(
            "NiftiImage(shape={:?}, dtype={}, spacing={:?}, orientation={})",
            self.shape(),
            self.dtype(),
            self.spacing(),
            self.orientation()
        )
    }
}

/// The 4x4 identity matrix, used when no affine is supplied.
const IDENTITY_AFFINE: [[f32; 4]; 4] = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
];

/// Parses the `method` argument of the `NiftiImage` resampling methods (`None` = trilinear).
fn interpolation_from_method(method: Option<&str>) -> PyResult<Interpolation> {
    match method.unwrap_or("trilinear") {
        "trilinear" | "linear" => Ok(Interpolation::Trilinear),
        "nearest" => Ok(Interpolation::Nearest),
        _ => Err(PyValueError::new_err(
            "method must be 'trilinear' or 'nearest'",
        )),
    }
}
/// Returns a new image produced by `transforms::z_normalization`.
#[pyfunction]
fn z_normalization(image: &PyNiftiImage) -> PyNiftiImage {
    let normalized = transforms::z_normalization(&image.inner);
    PyNiftiImage { inner: normalized }
}
/// Returns a new image produced by `transforms::rescale_intensity`.
///
/// `output_range` is `(min, max)` and defaults to `(0.0, 1.0)`.
#[pyfunction]
#[pyo3(signature = (image, output_range=(0.0, 1.0)))]
fn rescale_intensity(image: &PyNiftiImage, output_range: (f64, f64)) -> PyNiftiImage {
    let rescaled = transforms::rescale_intensity(&image.inner, output_range.0, output_range.1);
    PyNiftiImage { inner: rescaled }
}
/// Returns a new image produced by `transforms::clamp` with bounds `min_value`/`max_value`.
#[pyfunction]
fn clamp(image: &PyNiftiImage, min_value: f64, max_value: f64) -> PyNiftiImage {
    let clamped = transforms::clamp(&image.inner, min_value, max_value);
    PyNiftiImage { inner: clamped }
}
/// Crops or pads `image` to a 3-element `target_shape` via `transforms::crop_or_pad`.
///
/// Raises `ValueError` if `target_shape` does not have exactly 3 elements or the transform
/// fails.
#[pyfunction]
fn crop_or_pad(image: &PyNiftiImage, target_shape: Vec<usize>) -> PyResult<PyNiftiImage> {
    let is_three_d = target_shape.len() == 3;
    if !is_three_d {
        return Err(PyValueError::new_err(
            "target_shape must be a 3-element sequence",
        ));
    }
    let resized = transforms::crop_or_pad(&image.inner, &target_shape)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(PyNiftiImage { inner: resized })
}
/// Resamples `image` to `target_spacing` via `transforms::resample_to_spacing`.
///
/// `method` is `"trilinear"` (default; alias `"linear"`) or `"nearest"`. Anything else raises
/// `ValueError`.
#[pyfunction]
#[pyo3(signature = (image, target_spacing, method=None))]
fn resample(
    image: &PyNiftiImage,
    target_spacing: (f32, f32, f32),
    method: Option<&str>,
) -> PyResult<PyNiftiImage> {
    let requested = method.unwrap_or("trilinear");
    let interp = if requested == "trilinear" || requested == "linear" {
        Interpolation::Trilinear
    } else if requested == "nearest" {
        Interpolation::Nearest
    } else {
        return Err(PyValueError::new_err(format!(
            "method must be 'trilinear' or 'nearest', got {}",
            requested
        )));
    };
    let (sx, sy, sz) = target_spacing;
    let resampled = transforms::resample_to_spacing(&image.inner, [sx, sy, sz], interp);
    Ok(PyNiftiImage { inner: resampled })
}
/// Reorients `image` to the given orientation code via `transforms::reorient`.
///
/// Raises `ValueError` if `Orientation::from_str` does not recognise the code.
#[pyfunction]
fn reorient(image: &PyNiftiImage, orientation: &str) -> PyResult<PyNiftiImage> {
    match Orientation::from_str(orientation) {
        Some(target) => Ok(PyNiftiImage {
            inner: transforms::reorient(&image.inner, target),
        }),
        None => Err(PyValueError::new_err(format!(
            "Invalid orientation code: {}. Expected three-letter code like RAS or LPS",
            orientation
        ))),
    }
}
/// Loads a NIfTI file with `nifti::load`. Failures are raised as `ValueError`.
#[pyfunction]
fn load(path: &str) -> PyResult<PyNiftiImage> {
    match nifti::load(path) {
        Ok(inner) => Ok(PyNiftiImage { inner }),
        Err(e) => Err(PyValueError::new_err(format!("Failed to load {}: {}", path, e))),
    }
}
/// Loads a NIfTI file and converts it to a `torch.Tensor` on `device` (default `"cpu"`),
/// optionally cast to `dtype`.
#[pyfunction]
#[pyo3(signature = (path, dtype=None, device="cpu"))]
fn load_to_torch(
    py: Python<'_>,
    path: &str,
    dtype: Option<PyObject>,
    device: &str,
) -> PyResult<PyObject> {
    let wrapped = match nifti::load(path) {
        Ok(inner) => PyNiftiImage { inner },
        Err(e) => {
            return Err(PyValueError::new_err(format!(
                "Failed to load {}: {}",
                path, e
            )))
        }
    };
    wrapped.to_torch_with_dtype_and_device(py, dtype, Some(device))
}
/// Loads only the `crop_shape` region starting at `crop_offset` via `nifti::load_cropped`.
/// Failures are raised as `ValueError`.
#[pyfunction]
fn load_cropped(
    path: &str,
    crop_offset: [usize; 3],
    crop_shape: [usize; 3],
) -> PyResult<PyNiftiImage> {
    match nifti::load_cropped(path, crop_offset, crop_shape) {
        Ok(inner) => Ok(PyNiftiImage { inner }),
        Err(e) => Err(PyValueError::new_err(format!(
            "Failed to load cropped {}: {}",
            path, e
        ))),
    }
}
#[pyfunction]
#[pyo3(signature = (path, output_shape, target_spacing=None, target_orientation=None, output_offset=None))]
fn load_resampled(
path: &str,
output_shape: [usize; 3],
target_spacing: Option<[f32; 3]>,
target_orientation: Option<String>,
output_offset: Option<[usize; 3]>,
) -> PyResult<PyNiftiImage> {
use crate::transforms::Orientation;
let orientation = match target_orientation {
Some(s) => Some(
Orientation::from_str(&s)
.ok_or_else(|| PyValueError::new_err(format!("Invalid orientation: {}", s)))?,
),
None => None,
};
let config = nifti::LoadCroppedConfig {
output_shape,
target_spacing,
target_orientation: orientation,
output_offset,
};
nifti::load_cropped_config(path, config)
.map(|inner| PyNiftiImage { inner })
.map_err(|e| PyValueError::new_err(format!("Failed to load cropped {}: {}", path, e)))
}
#[pyfunction]
#[pyo3(signature = (path, output_shape, target_spacing=None, target_orientation=None, output_offset=None, dtype=None, device="cpu"))]
fn load_cropped_to_torch(
py: Python<'_>,
path: &str,
output_shape: [usize; 3],
target_spacing: Option<[f32; 3]>,
target_orientation: Option<String>,
output_offset: Option<[usize; 3]>,
dtype: Option<PyObject>,
device: &str,
) -> PyResult<PyObject> {
use crate::transforms::Orientation;
let orientation = match target_orientation {
Some(s) => Some(
Orientation::from_str(&s)
.ok_or_else(|| PyValueError::new_err(format!("Invalid orientation: {}", s)))?,
),
None => None,
};
let config = nifti::LoadCroppedConfig {
output_shape,
target_spacing,
target_orientation: orientation,
output_offset,
};
let img = nifti::load_cropped_config(path, config)
.map_err(|e| PyValueError::new_err(format!("Failed to load cropped {}: {}", path, e)))?;
let py_img = PyNiftiImage { inner: img };
py_img.to_torch_with_dtype_and_device(py, dtype, Some(device))
}
#[pyfunction]
#[pyo3(signature = (path, output_shape, target_spacing=None, target_orientation=None, output_offset=None, dtype=None, device="cpu"))]
fn load_cropped_to_jax(
py: Python<'_>,
path: &str,
output_shape: [usize; 3],
target_spacing: Option<[f32; 3]>,
target_orientation: Option<String>,
output_offset: Option<[usize; 3]>,
dtype: Option<PyObject>,
device: &str,
) -> PyResult<PyObject> {
use crate::transforms::Orientation;
let orientation = match target_orientation {
Some(s) => Some(
Orientation::from_str(&s)
.ok_or_else(|| PyValueError::new_err(format!("Invalid orientation: {}", s)))?,
),
None => None,
};
let config = nifti::LoadCroppedConfig {
output_shape,
target_spacing,
target_orientation: orientation,
output_offset,
};
let img = nifti::load_cropped_config(path, config)
.map_err(|e| PyValueError::new_err(format!("Failed to load cropped {}: {}", path, e)))?;
let py_img = PyNiftiImage { inner: img };
py_img.to_jax_with_dtype_and_device(py, dtype, Some(device))
}
/// Python wrapper around `nifti::TrainingDataLoader`, producing image patches.
#[pyclass]
pub struct PyTrainingDataLoader {
    /// The wrapped Rust loader.
    loader: nifti::TrainingDataLoader,
}
#[pymethods]
impl PyTrainingDataLoader {
    /// Creates a patch loader over `volumes`.
    ///
    /// `cache_size` defaults to 1000. Raises `ValueError` when any `patch_overlap` component
    /// is not strictly smaller than the matching `patch_size` component, or when the
    /// underlying loader cannot be constructed.
    #[new]
    #[pyo3(signature = (volumes, patch_size, patches_per_volume, patch_overlap, randomize, cache_size=None))]
    fn new(
        volumes: Vec<String>,
        patch_size: [usize; 3],
        patches_per_volume: usize,
        patch_overlap: [usize; 3],
        randomize: bool,
        cache_size: Option<usize>,
    ) -> PyResult<Self> {
        let overlap_fits = patch_overlap
            .iter()
            .zip(patch_size.iter())
            .all(|(overlap, size)| overlap < size);
        if !overlap_fits {
            return Err(PyValueError::new_err(
                "patch_overlap must be smaller than patch_size in all dimensions",
            ));
        }
        let config = nifti::CropLoaderConfig {
            patch_size,
            patches_per_volume,
            patch_overlap,
            randomize,
        };
        nifti::TrainingDataLoader::new(volumes, config, cache_size.unwrap_or(1000))
            .map(|loader| Self { loader })
            .map_err(|e| PyValueError::new_err(format!("Failed to create loader: {}", e)))
    }

    /// Returns the next patch from the loader. Failures are raised as `ValueError`.
    fn next_patch(&mut self) -> PyResult<PyNiftiImage> {
        match self.loader.next_patch() {
            Ok(inner) => Ok(PyNiftiImage { inner }),
            Err(e) => Err(PyValueError::new_err(format!(
                "Failed to get next patch: {}",
                e
            ))),
        }
    }

    /// Number of volumes times patches per volume.
    fn __len__(&self) -> usize {
        let volume_count = self.loader.volumes_len();
        let per_volume = self.loader.patches_per_volume();
        volume_count * per_volume
    }

    /// Resets the loader and returns it as its own iterator.
    fn __iter__(mut slf: PyRefMut<'_, Self>) -> PyResult<PyRefMut<'_, Self>> {
        slf.reset()?;
        Ok(slf)
    }

    /// Yields the next patch; ends iteration when `next_patch` fails.
    ///
    /// NOTE(review): every error, not only exhaustion, becomes `StopIteration`, so an I/O
    /// failure mid-epoch silently ends iteration. Confirm this is intended.
    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<PyNiftiImage> {
        let outcome = slf.next_patch();
        outcome.ok()
    }

    /// Loader statistics rendered with `Display`.
    fn stats(&self) -> String {
        self.loader.stats().to_string()
    }

    /// Resets the underlying loader. Failures are raised as `ValueError`.
    fn reset(&mut self) -> PyResult<()> {
        match self.loader.reset() {
            Ok(()) => Ok(()),
            Err(e) => Err(PyValueError::new_err(format!(
                "Failed to reset loader: {}",
                e
            ))),
        }
    }
}
/// Python wrapper around the crate's `TransformPipeline`, exposed as `TransformPipeline`.
#[pyclass(name = "TransformPipeline")]
pub struct PyTransformPipeline {
    /// The wrapped Rust pipeline that the builder methods extend.
    inner: RustTransformPipeline,
}
#[pymethods]
impl PyTransformPipeline {
    /// Creates an empty pipeline.
    ///
    /// With `lazy=False`, `.lazy(false)` is applied to `RustTransformPipeline::new()`.
    /// Otherwise the default pipeline is used unchanged.
    #[new]
    #[pyo3(signature = (lazy=true))]
    fn new(lazy: bool) -> Self {
        let pipeline = RustTransformPipeline::new();
        Self {
            inner: if lazy { pipeline } else { pipeline.lazy(false) },
        }
    }

    /// Appends a z-normalization step and returns `self` for chaining.
    fn z_normalize(mut self_: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut self_.inner);
        self_.inner = pipeline.z_normalize();
        self_
    }

    /// Appends an intensity-rescale step to `[out_min, out_max]`.
    fn rescale(mut self_: PyRefMut<'_, Self>, out_min: f32, out_max: f32) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut self_.inner);
        self_.inner = pipeline.rescale(out_min, out_max);
        self_
    }

    /// Appends a clamp step with bounds `min`/`max`.
    fn clamp(mut self_: PyRefMut<'_, Self>, min: f32, max: f32) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut self_.inner);
        self_.inner = pipeline.clamp(min, max);
        self_
    }

    /// Appends a resample-to-spacing step.
    fn resample_to_spacing(mut self_: PyRefMut<'_, Self>, spacing: [f32; 3]) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut self_.inner);
        self_.inner = pipeline.resample_to_spacing(spacing);
        self_
    }

    /// Appends a resample-to-shape step.
    fn resample_to_shape(mut self_: PyRefMut<'_, Self>, shape: [usize; 3]) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut self_.inner);
        self_.inner = pipeline.resample_to_shape(shape);
        self_
    }

    /// Appends a flip step over `axes`.
    fn flip(mut self_: PyRefMut<'_, Self>, axes: Vec<usize>) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut self_.inner);
        self_.inner = pipeline.flip(&axes);
        self_
    }

    /// Sets the pipeline's lazy flag.
    fn set_lazy(mut self_: PyRefMut<'_, Self>, lazy: bool) -> PyRefMut<'_, Self> {
        let pipeline = std::mem::take(&mut self_.inner);
        self_.inner = pipeline.lazy(lazy);
        self_
    }

    /// Runs the pipeline on `image` and returns the resulting image.
    fn apply(&self, image: &PyNiftiImage) -> PyNiftiImage {
        let output = self.inner.apply(&image.inner);
        PyNiftiImage { inner: output }
    }

    fn __repr__(&self) -> String {
        String::from("TransformPipeline(...)")
    }
}
/// Applies `transforms::random_flip` over `axes`. `prob` and `seed` are forwarded unchanged.
/// Errors are raised as `ValueError`.
#[pyfunction]
#[pyo3(signature = (image, axes, prob=None, seed=None))]
fn random_flip(
    image: &PyNiftiImage,
    axes: Vec<usize>,
    prob: Option<f32>,
    seed: Option<u64>,
) -> PyResult<PyNiftiImage> {
    let flipped = transforms::random_flip(&image.inner, &axes, prob, seed)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(PyNiftiImage { inner: flipped })
}
/// Applies `transforms::random_gaussian_noise`. `std` and `seed` are forwarded unchanged.
#[pyfunction]
#[pyo3(signature = (image, std=None, seed=None))]
fn random_gaussian_noise(
    image: &PyNiftiImage,
    std: Option<f32>,
    seed: Option<u64>,
) -> PyNiftiImage {
    let noisy = transforms::random_gaussian_noise(&image.inner, std, seed);
    PyNiftiImage { inner: noisy }
}
/// Applies `transforms::random_intensity_scale`. `scale_range` and `seed` are forwarded
/// unchanged.
#[pyfunction]
#[pyo3(signature = (image, scale_range=None, seed=None))]
fn random_intensity_scale(
    image: &PyNiftiImage,
    scale_range: Option<f32>,
    seed: Option<u64>,
) -> PyNiftiImage {
    let scaled = transforms::random_intensity_scale(&image.inner, scale_range, seed);
    PyNiftiImage { inner: scaled }
}
/// Applies `transforms::random_intensity_shift`. `shift_range` and `seed` are forwarded
/// unchanged.
#[pyfunction]
#[pyo3(signature = (image, shift_range=None, seed=None))]
fn random_intensity_shift(
    image: &PyNiftiImage,
    shift_range: Option<f32>,
    seed: Option<u64>,
) -> PyNiftiImage {
    let shifted = transforms::random_intensity_shift(&image.inner, shift_range, seed);
    PyNiftiImage { inner: shifted }
}
/// Applies `transforms::random_rotate_90` in the plane given by `axes`. Errors are raised as
/// `ValueError`.
#[pyfunction]
#[pyo3(signature = (image, axes, seed=None))]
fn random_rotate_90(
    image: &PyNiftiImage,
    axes: (usize, usize),
    seed: Option<u64>,
) -> PyResult<PyNiftiImage> {
    let rotated = transforms::random_rotate_90(&image.inner, axes, seed)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(PyNiftiImage { inner: rotated })
}
/// Applies `transforms::random_gamma`. `gamma_range` and `seed` are forwarded unchanged.
#[pyfunction]
#[pyo3(signature = (image, gamma_range=None, seed=None))]
fn random_gamma(
    image: &PyNiftiImage,
    gamma_range: Option<(f32, f32)>,
    seed: Option<u64>,
) -> PyNiftiImage {
    let adjusted = transforms::random_gamma(&image.inner, gamma_range, seed);
    PyNiftiImage { inner: adjusted }
}
/// Applies `transforms::random_augment`. Errors are raised as `ValueError`.
#[pyfunction]
#[pyo3(signature = (image, seed=None))]
fn random_augment(image: &PyNiftiImage, seed: Option<u64>) -> PyResult<PyNiftiImage> {
    let augmented = transforms::random_augment(&image.inner, seed)
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(PyNiftiImage { inner: augmented })
}
#[pymodule]
fn _medrs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyNiftiImage>()?;
m.add_class::<PyTrainingDataLoader>()?;
m.add_class::<PyTransformPipeline>()?;
m.add_function(wrap_pyfunction!(z_normalization, m)?)?;
m.add_function(wrap_pyfunction!(rescale_intensity, m)?)?;
m.add_function(wrap_pyfunction!(clamp, m)?)?;
m.add_function(wrap_pyfunction!(crop_or_pad, m)?)?;
m.add_function(wrap_pyfunction!(resample, m)?)?;
m.add_function(wrap_pyfunction!(reorient, m)?)?;
m.add_function(wrap_pyfunction!(load, m)?)?;
m.add_function(wrap_pyfunction!(load_to_torch, m)?)?;
m.add_function(wrap_pyfunction!(load_cropped, m)?)?;
m.add_function(wrap_pyfunction!(load_resampled, m)?)?;
m.add_function(wrap_pyfunction!(load_cropped_to_torch, m)?)?;
m.add_function(wrap_pyfunction!(load_cropped_to_jax, m)?)?;
m.add_function(wrap_pyfunction!(load_label_aware_cropped, m)?)?;
m.add_function(wrap_pyfunction!(compute_crop_regions, m)?)?;
m.add_function(wrap_pyfunction!(compute_random_spatial_crops, m)?)?;
m.add_function(wrap_pyfunction!(compute_center_crop, m)?)?;
m.add_function(wrap_pyfunction!(random_flip, m)?)?;
m.add_function(wrap_pyfunction!(random_gaussian_noise, m)?)?;
m.add_function(wrap_pyfunction!(random_intensity_scale, m)?)?;
m.add_function(wrap_pyfunction!(random_intensity_shift, m)?)?;
m.add_function(wrap_pyfunction!(random_rotate_90, m)?)?;
m.add_function(wrap_pyfunction!(random_gamma, m)?)?;
m.add_function(wrap_pyfunction!(random_augment, m)?)?;
Ok(())
}
/// Copies `data` into a NumPy array of the same element type, reshaped to `shape`.
///
/// `F16` and `BF16` have no NumPy counterpart here and are widened to `f32`.
fn arraydata_to_numpy<'py>(
    py: Python<'py>,
    data: &ArrayData,
    shape: &[usize],
) -> PyResult<PyObject> {
    let target = IxDyn(shape);

    // Moves an owned ndarray into NumPy, reshapes it to `target`, and erases the element type.
    macro_rules! publish {
        ($owned:expr) => {
            $owned.into_pyarray(py).reshape(target)?.unbind().into_any()
        };
    }

    let object = match data {
        ArrayData::U8(a) => publish!(a.to_owned()),
        ArrayData::I8(a) => publish!(a.to_owned()),
        ArrayData::I16(a) => publish!(a.to_owned()),
        ArrayData::U16(a) => publish!(a.to_owned()),
        ArrayData::I32(a) => publish!(a.to_owned()),
        ArrayData::U32(a) => publish!(a.to_owned()),
        ArrayData::I64(a) => publish!(a.to_owned()),
        ArrayData::U64(a) => publish!(a.to_owned()),
        ArrayData::F16(a) => publish!(a.mapv(|v| v.to_f32())),
        ArrayData::BF16(a) => publish!(a.mapv(|v| v.to_f32())),
        ArrayData::F32(a) => publish!(a.to_owned()),
        ArrayData::F64(a) => publish!(a.to_owned()),
    };
    Ok(object)
}
/// Copies the image data into NumPy when `as_view_f32` provides a float32 view.
///
/// Returns `None` otherwise so the caller can fall back to a conversion.
/// `PyArrayDyn::from_array` copies, so the result never aliases the image.
fn to_numpy_view<'py>(
    py: Python<'py>,
    image: &RustNiftiImage,
) -> Option<Bound<'py, PyArrayDyn<f32>>> {
    image
        .as_view_f32()
        .map(|view| PyArrayDyn::from_array(py, &view))
}
fn to_numpy_array<'py>(py: Python<'py>, image: &RustNiftiImage) -> Bound<'py, PyArrayDyn<f32>> {
let data = image.to_f32();
let shape: Vec<usize> = data.shape().to_vec();
let slice = data.as_slice_memory_order().expect("Array should be contiguous");
let np = py.import("numpy").expect("numpy import failed");
let flat = slice.to_vec().into_pyarray(py);
np.call_method(
"reshape",
(flat, &shape),
Some(&[("order", "F")].into_py_dict(py).expect("dict failed")),
)
.expect("reshape failed")
.extract::<Bound<'py, PyArrayDyn<f32>>>()
.expect("extract failed")
}
fn to_numpy_view_native<'py>(py: Python<'py>, image: &RustNiftiImage) -> Option<PyObject> {
let arr_obj = match image.dtype() {
DataType::UInt8 => image
.as_view_t::<u8>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::Int8 => image
.as_view_t::<i8>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::Int16 => image
.as_view_t::<i16>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::UInt16 => image
.as_view_t::<u16>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::Int32 => image
.as_view_t::<i32>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::UInt32 => image
.as_view_t::<u32>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::Int64 => image
.as_view_t::<i64>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::UInt64 => image
.as_view_t::<u64>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::Float16 => None, DataType::BFloat16 => None, DataType::Float32 => image
.as_view_t::<f32>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
DataType::Float64 => image
.as_view_t::<f64>()
.map(|v| PyArrayDyn::from_array(py, &v).unbind().into_any()),
};
arr_obj
}
/// Looks up the `torch` dtype object matching `dtype`.
///
/// Returns `None` if torch cannot be imported, the attribute is missing, or the type is
/// UInt16, UInt32 or UInt64, which this module treats as unsupported by torch.
fn torch_dtype(py: Python<'_>, dtype: DataType) -> Option<PyObject> {
    let torch = py.import("torch").ok()?;
    let name = match dtype {
        DataType::UInt8 => Some("uint8"),
        DataType::Int8 => Some("int8"),
        DataType::Int16 => Some("int16"),
        DataType::Int32 => Some("int32"),
        DataType::Int64 => Some("int64"),
        DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => None,
        DataType::Float16 => Some("float16"),
        DataType::BFloat16 => Some("bfloat16"),
        DataType::Float32 => Some("float32"),
        DataType::Float64 => Some("float64"),
    }?;
    let attr = torch.getattr(name).ok()?;
    Some(attr.unbind())
}
/// Loads one label-aware patch from `image_path` and the matching patch from `label_path`.
///
/// Both volumes are first loaded in full so `compute_label_aware_crop_regions` can choose a
/// single region. That region is then re-read from each file with `nifti::load_cropped`.
/// `pos_neg_ratio` defaults to 1.0 and `min_pos_samples` to 4. `background_label` is fixed at
/// 0.0.
///
/// Raises `ValueError` if `patch_size` does not have exactly 3 elements or no crop region is
/// found, and `IOError` if any read fails.
#[pyfunction]
#[pyo3(signature = (image_path, label_path, patch_size, pos_neg_ratio=None, min_pos_samples=None, seed=None))]
fn load_label_aware_cropped(
    _py: Python<'_>,
    image_path: &str,
    label_path: &str,
    patch_size: Vec<usize>,
    pos_neg_ratio: Option<f64>,
    min_pos_samples: Option<usize>,
    seed: Option<u64>,
) -> PyResult<(PyNiftiImage, PyNiftiImage)> {
    // Validate before any I/O. Indexing a short Vec would otherwise panic (PanicException).
    let patch_size = match patch_size.as_slice() {
        &[d, h, w] => [d, h, w],
        _ => {
            return Err(PyValueError::new_err(
                "patch_size must be a 3-element sequence",
            ))
        }
    };
    let image = nifti::load(image_path).map_err(|e| {
        PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load image: {}", e))
    })?;
    let label = nifti::load(label_path).map_err(|e| {
        PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load label: {}", e))
    })?;
    let config = transforms::RandCropByPosNegLabelConfig {
        patch_size,
        pos_neg_ratio: pos_neg_ratio.unwrap_or(1.0) as f32,
        min_pos_samples: min_pos_samples.unwrap_or(4),
        seed,
        background_label: 0.0,
    };
    let crop_regions = compute_label_aware_crop_regions(&config, &image, &label, 1);
    if crop_regions.is_empty() {
        return Err(PyValueError::new_err("No valid crop regions found"));
    }
    let region = &crop_regions[0];
    let cropped_image =
        nifti::load_cropped(image_path, region.start, region.size).map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
                "Failed to load cropped image: {}",
                e
            ))
        })?;
    let cropped_label =
        nifti::load_cropped(label_path, region.start, region.size).map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyIOError, _>(format!(
                "Failed to load cropped label: {}",
                e
            ))
        })?;
    Ok((
        PyNiftiImage {
            inner: cropped_image,
        },
        PyNiftiImage {
            inner: cropped_label,
        },
    ))
}
#[pyfunction]
#[pyo3(signature = (image_path, label_path, patch_size, num_samples, pos_neg_ratio=None, min_pos_samples=None, seed=None))]
fn compute_crop_regions(
py: Python<'_>,
image_path: &str,
label_path: &str,
patch_size: Vec<usize>,
num_samples: usize,
pos_neg_ratio: Option<f64>,
min_pos_samples: Option<usize>,
seed: Option<u64>,
) -> PyResult<Vec<PyObject>> {
let image = nifti::load(image_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load image: {}", e))
})?;
let label = nifti::load(label_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load label: {}", e))
})?;
let config = transforms::RandCropByPosNegLabelConfig {
patch_size: [patch_size[0], patch_size[1], patch_size[2]],
pos_neg_ratio: pos_neg_ratio.unwrap_or(1.0) as f32,
min_pos_samples: min_pos_samples.unwrap_or(4),
seed,
background_label: 0.0,
};
let crop_regions = compute_label_aware_crop_regions(&config, &image, &label, num_samples);
let mut regions_py = Vec::new();
for region in crop_regions {
let region_dict = pyo3::types::PyDict::new(py);
region_dict.set_item("start", region.start.to_vec())?;
region_dict.set_item("end", region.end.to_vec())?;
region_dict.set_item("size", region.size.to_vec())?;
regions_py.push(region_dict.into());
}
Ok(regions_py)
}
#[pyfunction]
#[pyo3(signature = (image_path, patch_size, num_samples, seed=None, allow_smaller=None))]
fn compute_random_spatial_crops(
py: Python<'_>,
image_path: &str,
patch_size: Vec<usize>,
num_samples: usize,
seed: Option<u64>,
allow_smaller: Option<bool>,
) -> PyResult<Vec<PyObject>> {
let image = nifti::load(image_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load image: {}", e))
})?;
let config = transforms::SpatialCropConfig {
patch_size: [patch_size[0], patch_size[1], patch_size[2]],
seed,
allow_smaller: allow_smaller.unwrap_or(false),
};
let crop_regions = compute_random_spatial_crop_regions(&config, &image, num_samples);
let mut regions_py = Vec::new();
for region in crop_regions {
let region_dict = pyo3::types::PyDict::new(py);
region_dict.set_item("start", region.start.to_vec())?;
region_dict.set_item("end", region.end.to_vec())?;
region_dict.set_item("size", region.size.to_vec())?;
regions_py.push(region_dict.into());
}
Ok(regions_py)
}
#[pyfunction]
fn compute_center_crop(
py: Python<'_>,
image_path: &str,
patch_size: Vec<usize>,
) -> PyResult<PyObject> {
let image = nifti::load(image_path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(format!("Failed to load image: {}", e))
})?;
let region = compute_center_crop_regions([patch_size[0], patch_size[1], patch_size[2]], &image);
let region_dict = pyo3::types::PyDict::new(py);
region_dict.set_item("start", region.start.to_vec())?;
region_dict.set_item("end", region.end.to_vec())?;
region_dict.set_item("size", region.size.to_vec())?;
Ok(region_dict.into())
}