use crate::nifti::image::ArrayData;
use crate::nifti::DataType;
use crate::nifti::NiftiImage;
use crate::pipeline::acquire_buffer;
use crate::pipeline::simd_kernels::{
parallel_linear_transform_f32, parallel_minmax_f32, parallel_sum_and_sum_sq_f32,
};
use crate::transforms::Interpolation as TransformsInterpolation;
use ndarray::{ArrayD, IxDyn};
/// A deferred image operation queued on a `LazyImage`.
///
/// Operations are recorded instead of being executed immediately so that
/// adjacent compatible operations can be merged into a single one.
#[derive(Clone, Debug)]
pub enum PendingOp {
    /// Resample the volume through a 4x4 homogeneous matrix.
    Affine {
        /// Matrix applied to an output voxel coordinate `(x, y, z, 1)` to
        /// obtain the source position that is sampled for that voxel.
        matrix: [[f32; 4]; 4],
        /// Shape of the resampled volume; `None` keeps the input's shape.
        output_shape: Option<[usize; 3]>,
        /// Sampling scheme used when reading the source volume.
        interpolation: Interpolation,
    },
    /// Z-score normalization, treated by fusion as `(x - mean) * inv_std`.
    ZNormalize {
        mean: f32,
        inv_std: f32,
    },
    /// Linear intensity map `x * scale + offset`.
    LinearIntensity {
        scale: f32,
        offset: f32,
    },
    /// Restrict intensities to `[min, max]`.
    Clamp {
        min: f32,
        max: f32,
    },
    /// Flip along the spatial axes whose bit is set (bit `i` selects axis `i`).
    Flip {
        axes: u8,
    },
}

/// Interpolation scheme for [`PendingOp::Affine`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Interpolation {
    Nearest,
    #[default]
    Trilinear,
}

impl PendingOp {
    /// Returns `true` when `self` followed by `other` can be replaced by a
    /// single equivalent operation.
    ///
    /// Defined in terms of [`PendingOp::fuse_with`] so the two can never
    /// disagree.
    pub fn can_fuse_with(&self, other: &PendingOp) -> bool {
        self.fuse_with(other).is_some()
    }

    /// Merges `self` followed by `other` into one operation, or returns
    /// `None` when no single [`PendingOp`] reproduces the pair.
    ///
    /// Supported pairs:
    /// * `Affine` + `Affine` with the same interpolation,
    /// * `LinearIntensity` + `LinearIntensity`,
    /// * `ZNormalize` + `LinearIntensity` (uses the statistics stored in the
    ///   `ZNormalize` variant).
    ///
    /// Anything followed by `Clamp` is *not* fused: no variant can carry both
    /// a linear map and clamp bounds, so collapsing the pair into a bare
    /// `Clamp` would silently drop the preceding intensity transform.
    pub fn fuse_with(&self, other: &PendingOp) -> Option<PendingOp> {
        match (self, other) {
            (
                PendingOp::Affine {
                    matrix: first,
                    output_shape: first_shape,
                    interpolation: first_interp,
                },
                PendingOp::Affine {
                    matrix: second,
                    output_shape: second_shape,
                    interpolation: second_interp,
                },
            ) if first_interp == second_interp => Some(PendingOp::Affine {
                // The matrices map output voxels back to source voxels, so
                // running `first` then `second` samples the original image
                // at `first * (second * p)`, i.e. the product `first * second`.
                matrix: compose_affine(second, first),
                // A second op without an explicit shape would have inherited
                // the shape produced by the first one.
                output_shape: second_shape.or(*first_shape),
                interpolation: *first_interp,
            }),
            (
                PendingOp::LinearIntensity {
                    scale: s1,
                    offset: o1,
                },
                PendingOp::LinearIntensity {
                    scale: s2,
                    offset: o2,
                },
            ) => Some(PendingOp::LinearIntensity {
                // (x * s1 + o1) * s2 + o2
                scale: s1 * s2,
                offset: o1 * s2 + o2,
            }),
            (
                PendingOp::ZNormalize { mean, inv_std },
                PendingOp::LinearIntensity { scale, offset },
            ) => Some(PendingOp::LinearIntensity {
                // ((x - mean) * inv_std) * scale + offset
                scale: inv_std * scale,
                offset: -mean * inv_std * scale + offset,
            }),
            _ => None,
        }
    }
}

/// Multiplies two 4x4 matrices and returns `b * a`.
///
/// Read as a forward transform this is "apply `a`, then `b`".
fn compose_affine(a: &[[f32; 4]; 4], b: &[[f32; 4]; 4]) -> [[f32; 4]; 4] {
    let mut result = [[0.0f32; 4]; 4];
    for (out_row, b_row) in result.iter_mut().zip(b.iter()) {
        for (j, cell) in out_row.iter_mut().enumerate() {
            for (k, a_row) in a.iter().enumerate() {
                *cell += b_row[k] * a_row[j];
            }
        }
    }
    result
}
/// An image (already loaded, or identified by a path) together with the
/// operations that have been queued on it but not executed yet.
#[derive(Clone)]
pub struct LazyImage {
    /// In-memory image, when one was supplied up front.
    pub(crate) image: Option<NiftiImage>,
    /// Path handed to `crate::nifti::load` when no in-memory image is present.
    pub(crate) path: Option<String>,
    /// Queued operations, in the order they will be executed.
    pub(crate) pending: Vec<PendingOp>,
}

impl LazyImage {
    /// Wraps an already loaded image with an empty operation queue.
    pub fn from_image(image: NiftiImage) -> Self {
        LazyImage {
            pending: Vec::new(),
            path: None,
            image: Some(image),
        }
    }

    /// Records a path to load later; nothing is read until `materialize`.
    pub fn from_path(path: impl Into<String>) -> Self {
        LazyImage {
            pending: Vec::new(),
            path: Some(path.into()),
            image: None,
        }
    }

    /// Queues `op`, merging it into the last queued operation when both
    /// `can_fuse_with` and `fuse_with` agree that the pair can be fused.
    pub fn push_op(&mut self, op: PendingOp) {
        let merged = self
            .pending
            .last()
            .filter(|tail| tail.can_fuse_with(&op))
            .and_then(|tail| tail.fuse_with(&op));
        match merged {
            Some(fused) => {
                // `merged` is only `Some` when a last element exists.
                if let Some(slot) = self.pending.last_mut() {
                    *slot = fused;
                }
            }
            None => self.pending.push(op),
        }
    }

    /// Whether at least one operation is queued.
    pub fn has_pending(&self) -> bool {
        self.pending_count() > 0
    }

    /// Number of queued operations (after fusion).
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Obtains the image (in-memory one first, otherwise loaded from the
    /// path) and runs every queued operation on it in order.
    ///
    /// # Errors
    /// Returns `InvalidDimensions` when neither an image nor a path is set,
    /// and propagates any error from loading or from an operation.
    pub fn materialize(self) -> crate::error::Result<NiftiImage> {
        let LazyImage {
            image,
            path,
            pending,
        } = self;
        let start = match (image, path) {
            (Some(loaded), _) => loaded,
            (None, Some(location)) => crate::nifti::load(&location)?,
            (None, None) => {
                return Err(crate::error::Error::InvalidDimensions(
                    "LazyImage has no image or path".into(),
                ));
            }
        };
        pending
            .iter()
            .try_fold(start, |current, op| execute_op(current, op))
    }

    /// Read-only view of the queued operations.
    pub fn pending_ops(&self) -> &[PendingOp] {
        self.pending.as_slice()
    }
}
fn execute_op(image: NiftiImage, op: &PendingOp) -> crate::error::Result<NiftiImage> {
use crate::transforms;
match op {
PendingOp::Affine {
matrix,
output_shape,
interpolation,
} => {
let shape = output_shape.unwrap_or_else(|| {
let shp = image.shape();
[shp[0], shp[1], shp[2]]
});
let interp = match interpolation {
Interpolation::Nearest => TransformsInterpolation::Nearest,
Interpolation::Trilinear => TransformsInterpolation::Trilinear,
};
apply_affine(&image, matrix, shape, interp)
}
PendingOp::ZNormalize { .. } => {
transforms::z_normalization(&image)
}
PendingOp::LinearIntensity { scale, offset } => {
apply_linear_intensity(&image, *scale, *offset)
}
PendingOp::Clamp { min, max } => transforms::clamp(&image, *min as f64, *max as f64),
PendingOp::Flip { axes } => {
let axes_vec: Vec<usize> = (0..3).filter(|&i| (axes >> i) & 1 == 1).collect();
transforms::flip(&image, &axes_vec)
}
}
}
/// Executes a run of intensity-only operations on `image`.
///
/// Supported sequences have the shape
/// `ZNormalize* -> LinearIntensity* -> Clamp?`. Consecutive linear ops are
/// folded into one `(scale, offset)` pair. Returns `None` when
/// * the list contains a non-intensity op (`Affine`, `Flip`),
/// * the ops appear in an order that cannot be reproduced faithfully
///   (e.g. a linear op after a clamp, or more than one clamp),
/// * the image cannot be converted to contiguous `f32` data.
///
/// `fuse_intensity_ops` cannot express an arbitrary linear map, so when a
/// linear op is combined with z-normalization or a clamp the stages are run
/// one after another instead of silently discarding the linear part.
pub fn execute_fused_intensity(image: &NiftiImage, pending: &[PendingOp]) -> Option<NiftiImage> {
    /// Converts `image` to `f32` and runs the linear kernel with
    /// `(scale, offset)` into a fresh Float32 image with neutral scaling.
    fn linear_pass(image: &NiftiImage, scale: f32, offset: f32) -> Option<NiftiImage> {
        use ndarray::ShapeBuilder;
        let data = image.to_f32().ok()?;
        let slice = data.as_slice_memory_order()?;
        let mut output = acquire_buffer(slice.len());
        parallel_linear_transform_f32(slice, &mut output, scale, offset);
        let out_array =
            ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(data.shape()).f(), output).ok()?;
        let mut header = image.header().clone();
        header.datatype = DataType::Float32;
        header.scl_slope = 1.0;
        header.scl_inter = 0.0;
        Some(NiftiImage::from_parts(header, ArrayData::F32(out_array)))
    }

    let mut do_znorm = false;
    let mut linear: Option<(f32, f32)> = None;
    let mut clamp: Option<(f32, f32)> = None;
    for op in pending {
        match op {
            PendingOp::ZNormalize { .. } => {
                // Normalizing after a linear map or clamp is not expressible
                // by the stages below.
                if linear.is_some() || clamp.is_some() {
                    return None;
                }
                do_znorm = true;
            }
            PendingOp::LinearIntensity { scale, offset } => {
                if clamp.is_some() {
                    return None;
                }
                let (acc_scale, acc_offset) = linear.unwrap_or((1.0, 0.0));
                // (x * acc_scale + acc_offset) * scale + offset
                linear = Some((acc_scale * scale, acc_offset * scale + offset));
            }
            PendingOp::Clamp { min, max } => {
                if clamp.is_some() {
                    return None;
                }
                clamp = Some((*min, *max));
            }
            _ => return None,
        }
    }

    let (scale, offset) = match linear {
        Some(coefficients) => coefficients,
        // No linear op: z-norm and clamp are handled in a single fused pass.
        None => return fuse_intensity_ops(image, do_znorm, None, clamp),
    };

    // Stage 1 (optional): z-normalization.
    let normalized;
    let source = if do_znorm {
        normalized = fuse_intensity_ops(image, true, None, None)?;
        &normalized
    } else {
        image
    };
    // Stage 2: the accumulated linear map.
    let scaled = linear_pass(source, scale, offset)?;
    // Stage 3 (optional): clamp.
    match clamp {
        Some(_) => fuse_intensity_ops(&scaled, false, None, clamp),
        None => Some(scaled),
    }
}
fn apply_linear_intensity(
image: &NiftiImage,
scale: f32,
offset: f32,
) -> crate::error::Result<NiftiImage> {
use super::acquire_buffer;
use super::simd_kernels::parallel_linear_transform_f32;
use crate::error::Error;
use crate::nifti::image::ArrayData;
use crate::nifti::DataType;
use ndarray::{ArrayD, IxDyn, ShapeBuilder};
let header = image.header().clone();
if let ArrayData::F32(a) = image.owned_data()? {
let slice = a
.as_slice_memory_order()
.ok_or_else(|| Error::InvalidDimensions("Array not contiguous".into()))?;
let mut output = acquire_buffer(slice.len());
parallel_linear_transform_f32(slice, &mut output, scale, offset);
let out_array = ArrayD::from_shape_vec(IxDyn(a.shape()).f(), output)
.map_err(|e| Error::InvalidDimensions(format!("Shape mismatch: {}", e)))?;
return Ok(NiftiImage::from_parts(header, ArrayData::F32(out_array)));
}
let data = image.to_f32()?;
let slice = data
.as_slice_memory_order()
.ok_or_else(|| Error::InvalidDimensions("Array not contiguous".into()))?;
let mut output = acquire_buffer(slice.len());
parallel_linear_transform_f32(slice, &mut output, scale, offset);
let out_array = ArrayD::from_shape_vec(IxDyn(data.shape()).f(), output)
.map_err(|e| Error::InvalidDimensions(format!("Shape mismatch: {}", e)))?;
let mut new_header = header;
new_header.datatype = DataType::Float32;
new_header.scl_slope = 1.0;
new_header.scl_inter = 0.0;
Ok(NiftiImage::from_parts(
new_header,
ArrayData::F32(out_array),
))
}
/// Runs z-normalization, range rescaling and clamping as one linear pass.
///
/// The stages are folded, in this order, into a single `(scale, offset)`
/// pair that is applied once to the `f32` view of `image`:
/// 1. `do_znorm`: subtract the mean and multiply by `1 / std` (a non-positive
///    variance leaves the scale at 1),
/// 2. `rescale = Some((out_min, out_max))`: map the current value range onto
///    `[out_min, out_max]` (a zero-width range is treated as width 1),
/// 3. `clamp = Some((min, max))`: clamp the transformed values.
///
/// Returns `None` when the image cannot be converted to contiguous `f32`
/// data, when z-normalization is requested on empty data, or when the output
/// array cannot be built. The result is a Float32 image with neutral header
/// scaling (`scl_slope = 1`, `scl_inter = 0`).
#[allow(clippy::similar_names)]
pub fn fuse_intensity_ops(
    image: &NiftiImage,
    do_znorm: bool,
    rescale: Option<(f32, f32)>,
    clamp: Option<(f32, f32)>,
) -> Option<NiftiImage> {
    use ndarray::ShapeBuilder;
    let data = image.to_f32().ok()?;
    let slice = data.as_slice_memory_order()?;
    let mut scale = 1.0f32;
    let mut offset = 0.0f32;

    if do_znorm {
        if slice.is_empty() {
            return None;
        }
        let (sum, sum_sq, count) = parallel_sum_and_sum_sq_f32(slice);
        if count == 0 {
            return None;
        }
        // Keep the mean in f64 while forming E[x^2] - mean^2: rounding it to
        // f32 first amplifies cancellation error when |mean| >> std.
        let n = count as f64;
        let mean = sum / n;
        let variance = sum_sq / n - mean * mean;
        let inv_std = if variance <= 0.0 {
            1.0
        } else {
            1.0 / (variance.sqrt() as f32)
        };
        scale *= inv_std;
        offset -= mean as f32 * inv_std;
    }

    if let Some((out_min, out_max)) = rescale {
        let (raw_min, raw_max) = parallel_minmax_f32(slice);
        // The rescale acts on the values produced by the previous stage, so
        // push the raw extrema through the transform accumulated so far.
        // `scale` is positive here (1 or a positive `inv_std`), which keeps
        // min/max in order.
        let cur_min = raw_min * scale + offset;
        let cur_max = raw_max * scale + offset;
        let span = cur_max - cur_min;
        let range = if span == 0.0 { 1.0 } else { span };
        let r_scale = (out_max - out_min) / range;
        let r_offset = out_min - cur_min * r_scale;
        scale *= r_scale;
        offset = offset * r_scale + r_offset;
    }

    let mut output = acquire_buffer(slice.len());
    match clamp {
        Some((min, max)) => {
            super::simd_kernels::parallel_linear_transform_clamp_f32(
                slice,
                &mut output,
                scale,
                offset,
                min,
                max,
            );
        }
        None => {
            parallel_linear_transform_f32(slice, &mut output, scale, offset);
        }
    }

    let out_array = ArrayD::from_shape_vec(IxDyn(data.shape()).f(), output).ok()?;
    let mut header = image.header().clone();
    header.datatype = DataType::Float32;
    header.scl_slope = 1.0;
    header.scl_inter = 0.0;
    Some(NiftiImage::from_parts(header, ArrayData::F32(out_array)))
}
/// Resamples the first three axes of `image` through `matrix`.
///
/// For every output voxel `(x, y, z)` the matrix is applied to `(x, y, z, 1)`
/// to obtain a source position, which is sampled with `interpolation`.
/// Positions outside the source volume (or not a number) produce `0.0`.
/// `output_shape[i]` is the extent of output axis `i`, and `x`/`y`/`z` refer
/// to axes 0/1/2.
///
/// Layout: the output array is created column-major (`.f()`), so the voxel
/// buffers are indexed with axis 0 varying fastest.
/// NOTE(review): this assumes `to_f32()` yields column-major data, as the
/// other functions in this file do when they rebuild arrays with `.f()` —
/// confirm against `NiftiImage::to_f32`.
///
/// # Errors
/// Returns `InvalidDimensions` when the image has fewer than three
/// dimensions, a spatial axis is empty, the data is not contiguous, or an
/// output extent does not fit the header's `u16` dimension fields.
#[allow(clippy::similar_names)]
fn apply_affine(
    image: &NiftiImage,
    matrix: &[[f32; 4]; 4],
    output_shape: [usize; 3],
    interpolation: TransformsInterpolation,
) -> crate::error::Result<NiftiImage> {
    use crate::error::Error;
    use ndarray::ShapeBuilder;
    use std::convert::TryFrom;

    let data = image.to_f32()?;
    let shape = data.shape();
    if shape.len() < 3 {
        return Err(Error::InvalidDimensions(format!(
            "Affine resampling requires at least 3 dimensions, got {}",
            shape.len()
        )));
    }
    let (in_x, in_y, in_z) = (shape[0], shape[1], shape[2]);
    if in_x == 0 || in_y == 0 || in_z == 0 {
        // Guards the `dim - 1` computations below against underflow.
        return Err(Error::InvalidDimensions(
            "Cannot resample an empty volume".into(),
        ));
    }

    // Validate the header dimensions up front instead of truncating later.
    let mut header_dims = [1u16; 7];
    for (slot, &extent) in header_dims.iter_mut().zip(output_shape.iter()) {
        *slot = u16::try_from(extent).map_err(|_| {
            Error::InvalidDimensions(format!(
                "Output dimension {} does not fit in the image header",
                extent
            ))
        })?;
    }

    let src = data
        .as_slice_memory_order()
        .ok_or_else(|| Error::InvalidDimensions("Array not contiguous".into()))?;

    // Column-major strides: axis 0 is contiguous.
    let stride_y = in_x;
    let stride_z = in_x * in_y;
    let (max_x, max_y, max_z) = ((in_x - 1) as f32, (in_y - 1) as f32, (in_z - 1) as f32);

    let [out_x, out_y, out_z] = output_shape;
    // Zero-initialized, so out-of-bounds voxels need no explicit write.
    let mut out = vec![0.0f32; out_x * out_y * out_z];

    for z in 0..out_z {
        for y in 0..out_y {
            for x in 0..out_x {
                let (ox, oy, oz) = (x as f32, y as f32, z as f32);
                let sx = matrix[0][0] * ox + matrix[0][1] * oy + matrix[0][2] * oz + matrix[0][3];
                let sy = matrix[1][0] * ox + matrix[1][1] * oy + matrix[1][2] * oz + matrix[1][3];
                let sz = matrix[2][0] * ox + matrix[2][1] * oy + matrix[2][2] * oz + matrix[2][3];

                // Written as a positive test so NaN coordinates count as outside.
                let inside = sx >= 0.0
                    && sx <= max_x
                    && sy >= 0.0
                    && sy <= max_y
                    && sz >= 0.0
                    && sz <= max_z;
                if !inside {
                    continue;
                }

                let idx = x + y * out_x + z * out_x * out_y;
                match interpolation {
                    TransformsInterpolation::Nearest => {
                        let xi = (sx.round() as usize).min(in_x - 1);
                        let yi = (sy.round() as usize).min(in_y - 1);
                        let zi = (sz.round() as usize).min(in_z - 1);
                        out[idx] = src[zi * stride_z + yi * stride_y + xi];
                    }
                    TransformsInterpolation::Trilinear => {
                        let x0 = sx.floor() as usize;
                        let y0 = sy.floor() as usize;
                        let z0 = sz.floor() as usize;
                        // Clamp the upper neighbours so the last sample of
                        // each axis interpolates with itself.
                        let x1 = (x0 + 1).min(in_x - 1);
                        let y1 = (y0 + 1).min(in_y - 1);
                        let z1 = (z0 + 1).min(in_z - 1);
                        let fx = sx - x0 as f32;
                        let fy = sy - y0 as f32;
                        let fz = sz - z0 as f32;
                        let c000 = src[z0 * stride_z + y0 * stride_y + x0];
                        let c001 = src[z0 * stride_z + y0 * stride_y + x1];
                        let c010 = src[z0 * stride_z + y1 * stride_y + x0];
                        let c011 = src[z0 * stride_z + y1 * stride_y + x1];
                        let c100 = src[z1 * stride_z + y0 * stride_y + x0];
                        let c101 = src[z1 * stride_z + y0 * stride_y + x1];
                        let c110 = src[z1 * stride_z + y1 * stride_y + x0];
                        let c111 = src[z1 * stride_z + y1 * stride_y + x1];
                        let c00 = c000 * (1.0 - fx) + c001 * fx;
                        let c01 = c010 * (1.0 - fx) + c011 * fx;
                        let c10 = c100 * (1.0 - fx) + c101 * fx;
                        let c11 = c110 * (1.0 - fx) + c111 * fx;
                        let c0 = c00 * (1.0 - fy) + c01 * fy;
                        let c1 = c10 * (1.0 - fy) + c11 * fy;
                        out[idx] = c0 * (1.0 - fz) + c1 * fz;
                    }
                }
            }
        }
    }

    let out_array = ArrayD::from_shape_vec(IxDyn(&[out_x, out_y, out_z]).f(), out)
        .map_err(|e| Error::InvalidDimensions(format!("Shape mismatch: {}", e)))?;
    let mut header = image.header().clone();
    header.ndim = 3;
    header.dim = header_dims;
    header.datatype = DataType::Float32;
    header.scl_slope = 1.0;
    header.scl_inter = 0.0;
    Ok(NiftiImage::from_parts(header, ArrayData::F32(out_array)))
}
/// A transform that can describe itself as deferred [`PendingOp`]s instead of
/// executing immediately.
pub trait LazyTransform {
/// Builds the pending operations for this transform, given the lazy image it
/// will be queued on.
///
/// NOTE(review): `None` presumably means the transform cannot be expressed
/// lazily — confirm against the implementors and callers.
fn to_pending_op(&self, image: &LazyImage) -> Option<Vec<PendingOp>>;
/// Defaults to `false`.
///
/// NOTE(review): presumably signals that the transform needs voxel data to
/// build its operations — confirm against the callers.
fn requires_data(&self) -> bool {
false
}
}