use crate::error::{Error, Result};
use crate::nifti::image::ArrayData;
use crate::nifti::{DataType, NiftiImage};
use crate::pipeline::acquire_buffer;
use ndarray::{ArrayD, IxDyn, ShapeBuilder};
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use rayon::prelude::*;
#[allow(clippy::option_if_let_else)] fn get_rng(seed: Option<u64>) -> ChaCha8Rng {
match seed {
Some(s) => ChaCha8Rng::seed_from_u64(s),
None => ChaCha8Rng::from_entropy(),
}
}
#[must_use = "this function returns a new image and does not modify the original"]
pub fn random_flip(
image: &NiftiImage,
axes: &[usize],
prob: Option<f32>,
seed: Option<u64>,
) -> Result<NiftiImage> {
let prob = prob.unwrap_or(0.5);
let mut rng = get_rng(seed);
let flip_axes: Vec<usize> = axes
.iter()
.filter(|_| rng.gen::<f32>() < prob)
.copied()
.collect();
if flip_axes.is_empty() {
Ok(image.clone())
} else {
crate::transforms::flip(image, &flip_axes)
}
}
#[must_use = "this function returns a Result and does not modify the original"]
pub fn random_gaussian_noise(
image: &NiftiImage,
std: Option<f32>,
seed: Option<u64>,
) -> Result<NiftiImage> {
let std = std.unwrap_or(0.1);
let mut rng = get_rng(seed);
let data = image.to_f32()?;
let slice = data.as_slice_memory_order().ok_or_else(|| {
Error::NonContiguousArray("Array must be contiguous for noise operation".to_string())
})?;
let len = slice.len();
let mut output = acquire_buffer(len);
for (i, out) in output.iter_mut().enumerate() {
let u1: f32 = rng.gen();
let u2: f32 = rng.gen();
let u1 = u1.max(1e-10);
let noise = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos() * std;
*out = slice[i] + noise;
}
let shape = data.shape();
let out_array = ArrayD::from_shape_vec(IxDyn(shape).f(), output)
.map_err(|e| Error::MemoryAllocation(format!("Failed to create output array: {}", e)))?;
let mut header = image.header().clone();
header.datatype = DataType::Float32;
header.scl_slope = 1.0;
header.scl_inter = 0.0;
Ok(NiftiImage::from_parts(header, ArrayData::F32(out_array)))
}
#[must_use = "this function returns a Result and does not modify the original"]
pub fn random_intensity_scale(
image: &NiftiImage,
scale_range: Option<f32>,
seed: Option<u64>,
) -> Result<NiftiImage> {
let scale_range = scale_range.unwrap_or(0.1);
let mut rng = get_rng(seed);
let scale = 1.0 + rng.gen_range(-scale_range..=scale_range);
let data = image.to_f32()?;
let slice = data.as_slice_memory_order().ok_or_else(|| {
Error::NonContiguousArray("Array must be contiguous for scale operation".to_string())
})?;
let mut output = acquire_buffer(slice.len());
output
.par_iter_mut()
.zip(slice.par_iter())
.for_each(|(out, &v)| {
*out = v * scale;
});
let shape = data.shape();
let out_array = ArrayD::from_shape_vec(IxDyn(shape).f(), output)
.map_err(|e| Error::MemoryAllocation(format!("Failed to create output array: {}", e)))?;
let mut header = image.header().clone();
header.datatype = DataType::Float32;
header.scl_slope = 1.0;
header.scl_inter = 0.0;
Ok(NiftiImage::from_parts(header, ArrayData::F32(out_array)))
}
#[must_use = "this function returns a Result and does not modify the original"]
pub fn random_intensity_shift(
image: &NiftiImage,
shift_range: Option<f32>,
seed: Option<u64>,
) -> Result<NiftiImage> {
let shift_range = shift_range.unwrap_or(0.1);
let mut rng = get_rng(seed);
let shift = rng.gen_range(-shift_range..=shift_range);
let data = image.to_f32()?;
let slice = data.as_slice_memory_order().ok_or_else(|| {
Error::NonContiguousArray("Array must be contiguous for shift operation".to_string())
})?;
let mut output = acquire_buffer(slice.len());
output
.par_iter_mut()
.zip(slice.par_iter())
.for_each(|(out, &v)| {
*out = v + shift;
});
let shape = data.shape();
let out_array = ArrayD::from_shape_vec(IxDyn(shape).f(), output)
.map_err(|e| Error::MemoryAllocation(format!("Failed to create output array: {}", e)))?;
let mut header = image.header().clone();
header.datatype = DataType::Float32;
header.scl_slope = 1.0;
header.scl_inter = 0.0;
Ok(NiftiImage::from_parts(header, ArrayData::F32(out_array)))
}
#[must_use = "this function returns a new image and does not modify the original"]
pub fn random_rotate_90(
image: &NiftiImage,
axes: (usize, usize),
seed: Option<u64>,
) -> Result<NiftiImage> {
let mut rng = get_rng(seed);
let k: usize = rng.gen_range(0..4);
if k == 0 {
return Ok(image.clone());
}
rotate_90(image, axes, k)
}
fn rotate_90(image: &NiftiImage, axes: (usize, usize), k: usize) -> Result<NiftiImage> {
let ndim = image.ndim();
if axes.0 >= ndim || axes.1 >= ndim {
return Err(Error::InvalidDimensions(format!(
"Rotation axes ({}, {}) out of bounds for {}D image",
axes.0, axes.1, ndim
)));
}
if axes.0 == axes.1 {
return Err(Error::InvalidDimensions(
"Rotation axes must be different".into(),
));
}
let k = k % 4;
if k == 0 {
return Ok(image.clone());
}
macro_rules! rotate_array {
($arr:expr, $variant:ident) => {{
use ndarray::Axis;
let mut arr = $arr.clone();
for _ in 0..k {
arr.swap_axes(axes.0, axes.1);
arr.invert_axis(Axis(axes.1));
}
ArrayData::$variant(arr.as_standard_layout().to_owned())
}};
}
let data = image.data_cow()?;
let new_data = match data.as_ref() {
ArrayData::U8(a) => rotate_array!(a, U8),
ArrayData::I8(a) => rotate_array!(a, I8),
ArrayData::I16(a) => rotate_array!(a, I16),
ArrayData::U16(a) => rotate_array!(a, U16),
ArrayData::I32(a) => rotate_array!(a, I32),
ArrayData::U32(a) => rotate_array!(a, U32),
ArrayData::I64(a) => rotate_array!(a, I64),
ArrayData::U64(a) => rotate_array!(a, U64),
ArrayData::F16(a) => rotate_array!(a, F16),
ArrayData::BF16(a) => rotate_array!(a, BF16),
ArrayData::F32(a) => rotate_array!(a, F32),
ArrayData::F64(a) => rotate_array!(a, F64),
};
let mut header = image.header().clone();
let old_shape = image.shape();
let mut new_shape = old_shape.to_vec();
if k % 2 == 1 {
new_shape.swap(axes.0, axes.1);
}
header.ndim = new_shape.len() as u8;
header.dim = [1u16; 7];
for (i, &s) in new_shape.iter().enumerate() {
if s > u16::MAX as usize {
return Err(Error::InvalidDimensions(format!(
"Rotated dimension {} ({}) exceeds maximum value {}",
i,
s,
u16::MAX
)));
}
header.dim[i] = s as u16;
}
Ok(NiftiImage::from_parts(header, new_data))
}
#[must_use = "this function returns a Result and does not modify the original"]
pub fn random_gamma(
image: &NiftiImage,
gamma_range: Option<(f32, f32)>,
seed: Option<u64>,
) -> Result<NiftiImage> {
let (gamma_min, gamma_max) = gamma_range.unwrap_or((0.7, 1.5));
let mut rng = get_rng(seed);
let gamma = rng.gen_range(gamma_min..=gamma_max);
let data = image.to_f32()?;
let slice = data.as_slice_memory_order().ok_or_else(|| {
Error::NonContiguousArray("Array must be contiguous for gamma operation".to_string())
})?;
let mut output = acquire_buffer(slice.len());
output
.par_iter_mut()
.zip(slice.par_iter())
.for_each(|(out, &v)| {
let v_clamped = v.max(0.0);
*out = v_clamped.powf(gamma);
});
let shape = data.shape();
let out_array = ArrayD::from_shape_vec(IxDyn(shape).f(), output)
.map_err(|e| Error::MemoryAllocation(format!("Failed to create output array: {}", e)))?;
let mut header = image.header().clone();
header.datatype = DataType::Float32;
header.scl_slope = 1.0;
header.scl_inter = 0.0;
Ok(NiftiImage::from_parts(header, ArrayData::F32(out_array)))
}
#[must_use = "this function returns a Result and does not modify the original"]
pub fn random_augment(image: &NiftiImage, seed: Option<u64>) -> Result<NiftiImage> {
let mut rng = get_rng(seed);
let flip_seed = rng.gen();
let scale_seed = rng.gen();
let shift_seed = rng.gen();
let noise_seed = rng.gen();
let result = random_flip(image, &[0, 1, 2], Some(0.5), Some(flip_seed))?;
let result = random_intensity_scale(&result, Some(0.1), Some(scale_seed))?;
let result = random_intensity_shift(&result, Some(0.1), Some(shift_seed))?;
let result = random_gaussian_noise(&result, Some(0.05), Some(noise_seed))?;
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::ArrayD;
fn create_test_image(data: Vec<f32>, shape: [usize; 3]) -> NiftiImage {
use ndarray::ShapeBuilder;
let c_order = ArrayD::from_shape_vec(shape.to_vec(), data).unwrap();
let mut f_order = ArrayD::zeros(ndarray::IxDyn(&shape).f());
f_order.assign(&c_order);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
NiftiImage::from_array(f_order, affine)
}
#[test]
fn test_random_flip_with_seed() {
let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
let img = create_test_image(data, [4, 4, 4]);
let result1 = random_flip(&img, &[0, 1, 2], Some(0.5), Some(42)).unwrap();
let result2 = random_flip(&img, &[0, 1, 2], Some(0.5), Some(42)).unwrap();
let d1 = result1.to_f32().unwrap();
let d2 = result2.to_f32().unwrap();
assert_eq!(d1, d2);
}
#[test]
fn test_random_gaussian_noise() {
let data = vec![0.5; 64];
let img = create_test_image(data.clone(), [4, 4, 4]);
let noisy = random_gaussian_noise(&img, Some(0.1), Some(42)).unwrap();
let result = noisy.to_f32().unwrap();
let original_sum: f32 = data.iter().sum();
let result_sum: f32 = result.iter().sum();
assert!((original_sum - result_sum).abs() > 0.01);
assert_eq!(result.shape(), &[4, 4, 4]);
}
#[test]
fn test_random_intensity_scale() {
let data: Vec<f32> = (1..=64).map(|i| i as f32).collect();
let img = create_test_image(data, [4, 4, 4]);
let scaled = random_intensity_scale(&img, Some(0.2), Some(42)).unwrap();
let result = scaled.to_f32().unwrap();
assert_eq!(result.shape(), &[4, 4, 4]);
let original = img.to_f32().unwrap();
assert_ne!(result, original);
}
#[test]
fn test_random_intensity_shift() {
let data: Vec<f32> = (1..=64).map(|i| i as f32).collect();
let img = create_test_image(data, [4, 4, 4]);
let shifted = random_intensity_shift(&img, Some(0.5), Some(42)).unwrap();
let result = shifted.to_f32().unwrap();
assert_eq!(result.shape(), &[4, 4, 4]);
}
#[test]
fn test_random_rotate_90() {
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
let img = create_test_image(data, [2, 3, 4]);
let rotated = random_rotate_90(&img, (0, 1), Some(42)).unwrap();
let shape = rotated.shape();
assert!(shape.iter().product::<usize>() == 24); }
#[test]
fn test_rotate_90_k0() {
let data: Vec<f32> = (0..8).map(|i| i as f32).collect();
let img = create_test_image(data.clone(), [2, 2, 2]);
let rotated = rotate_90(&img, (0, 1), 0).unwrap();
assert_eq!(rotated.shape(), &[2, 2, 2]);
}
#[test]
fn test_rotate_90_invalid_axes() {
let data: Vec<f32> = (0..8).map(|i| i as f32).collect();
let img = create_test_image(data, [2, 2, 2]);
let result = rotate_90(&img, (0, 5), 1);
assert!(result.is_err());
let result = rotate_90(&img, (1, 1), 1);
assert!(result.is_err());
}
#[test]
fn test_random_gamma() {
let data: Vec<f32> = (0..64).map(|i| (i as f32) / 63.0).collect();
let img = create_test_image(data, [4, 4, 4]);
let gamma_corrected = random_gamma(&img, Some((0.5, 2.0)), Some(42)).unwrap();
let result = gamma_corrected.to_f32().unwrap();
assert_eq!(result.shape(), &[4, 4, 4]);
for &v in result.iter() {
assert!(v >= 0.0);
}
}
#[test]
fn test_random_augment() {
let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
let img = create_test_image(data, [4, 4, 4]);
let augmented = random_augment(&img, Some(42)).unwrap();
assert_eq!(augmented.shape(), &[4, 4, 4]);
}
#[test]
fn test_reproducibility() {
let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
let img = create_test_image(data, [4, 4, 4]);
let aug1 = random_augment(&img, Some(12345)).unwrap();
let aug2 = random_augment(&img, Some(12345)).unwrap();
let d1 = aug1.to_f32().unwrap();
let d2 = aug2.to_f32().unwrap();
for (v1, v2) in d1.iter().zip(d2.iter()) {
assert!((v1 - v2).abs() < 1e-6);
}
}
}