use super::lazy::{execute_fused_intensity, LazyImage, LazyTransform, PendingOp};
use crate::error::Result;
use crate::nifti::NiftiImage;
pub trait Transform {
fn apply(&self, image: &NiftiImage) -> Result<NiftiImage>;
}
pub struct Compose {
transforms: Vec<Box<dyn TransformBox>>,
lazy: bool,
}
trait TransformBox: Send + Sync {
fn apply_eager(&self, image: &NiftiImage) -> Result<NiftiImage>;
fn to_pending(&self, image: &LazyImage) -> Option<Vec<PendingOp>>;
fn requires_data(&self) -> bool;
}
impl<T: Transform + LazyTransform + Send + Sync + 'static> TransformBox for T {
fn apply_eager(&self, image: &NiftiImage) -> Result<NiftiImage> {
self.apply(image)
}
fn to_pending(&self, image: &LazyImage) -> Option<Vec<PendingOp>> {
self.to_pending_op(image)
}
fn requires_data(&self) -> bool {
LazyTransform::requires_data(self)
}
}
impl Compose {
pub fn new() -> Self {
Self {
transforms: Vec::new(),
lazy: true,
}
}
pub fn eager() -> Self {
Self {
transforms: Vec::new(),
lazy: false,
}
}
pub fn push<T: Transform + LazyTransform + Send + Sync + 'static>(
mut self,
transform: T,
) -> Self {
self.transforms.push(Box::new(transform));
self
}
pub fn lazy(mut self, lazy: bool) -> Self {
self.lazy = lazy;
self
}
pub fn apply(&self, image: &NiftiImage) -> Result<NiftiImage> {
if self.lazy {
self.apply_lazy(image)
} else {
self.apply_eager(image)
}
}
fn apply_eager(&self, image: &NiftiImage) -> Result<NiftiImage> {
let mut result = image.clone();
for transform in &self.transforms {
result = transform.apply_eager(&result)?;
}
Ok(result)
}
fn apply_lazy(&self, image: &NiftiImage) -> Result<NiftiImage> {
let mut lazy_img = LazyImage::from_image(image.clone());
for transform in &self.transforms {
if transform.requires_data() {
if lazy_img.has_pending() {
let img = lazy_img.materialize()?;
let result = transform.apply_eager(&img)?;
lazy_img = LazyImage::from_image(result);
} else if let Some(img) = lazy_img.image.take() {
let result = transform.apply_eager(&img)?;
lazy_img = LazyImage::from_image(result);
}
} else if let Some(ops) = transform.to_pending(&lazy_img) {
for op in ops {
lazy_img.push_op(op);
}
} else {
let img = lazy_img.materialize()?;
let result = transform.apply_eager(&img)?;
lazy_img = LazyImage::from_image(result);
}
}
if lazy_img.has_pending() {
let pending = lazy_img.pending.clone();
if let Some(img) = &lazy_img.image {
if let Some(fused) = execute_fused_intensity(img, &pending) {
return Ok(fused);
}
}
}
lazy_img.materialize()
}
pub fn len(&self) -> usize {
self.transforms.len()
}
pub fn is_empty(&self) -> bool {
self.transforms.is_empty()
}
}
impl Default for Compose {
fn default() -> Self {
Self::new()
}
}
pub struct TransformPipeline {
compose: Compose,
}
impl TransformPipeline {
pub fn new() -> Self {
Self {
compose: Compose::new(),
}
}
pub fn z_normalize(self) -> Self {
Self {
compose: self.compose.push(ZNormalizeTransform),
}
}
pub fn rescale(self, out_min: f32, out_max: f32) -> Self {
Self {
compose: self.compose.push(RescaleTransform { out_min, out_max }),
}
}
pub fn clamp(self, min: f32, max: f32) -> Self {
Self {
compose: self.compose.push(ClampTransform { min, max }),
}
}
pub fn resample_to_spacing(self, spacing: [f32; 3]) -> Self {
Self {
compose: self.compose.push(ResampleSpacingTransform { spacing }),
}
}
pub fn resample_to_shape(self, shape: [usize; 3]) -> Self {
Self {
compose: self.compose.push(ResampleShapeTransform { shape }),
}
}
pub fn flip(self, axes: &[usize]) -> Self {
let mut mask = 0u8;
for &axis in axes {
mask |= 1 << axis;
}
Self {
compose: self.compose.push(FlipTransform { axes: mask }),
}
}
pub fn lazy(mut self, lazy: bool) -> Self {
self.compose = self.compose.lazy(lazy);
self
}
pub fn apply(&self, image: &NiftiImage) -> Result<NiftiImage> {
self.compose.apply(image)
}
pub fn into_compose(self) -> Compose {
self.compose
}
}
impl Default for TransformPipeline {
fn default() -> Self {
Self::new()
}
}
struct ZNormalizeTransform;
impl Transform for ZNormalizeTransform {
fn apply(&self, image: &NiftiImage) -> Result<NiftiImage> {
crate::transforms::z_normalization(image)
}
}
impl LazyTransform for ZNormalizeTransform {
fn to_pending_op(&self, _image: &LazyImage) -> Option<Vec<PendingOp>> {
None
}
fn requires_data(&self) -> bool {
true }
}
struct RescaleTransform {
out_min: f32,
out_max: f32,
}
impl Transform for RescaleTransform {
fn apply(&self, image: &NiftiImage) -> Result<NiftiImage> {
crate::transforms::rescale_intensity(image, self.out_min as f64, self.out_max as f64)
}
}
impl LazyTransform for RescaleTransform {
fn to_pending_op(&self, _image: &LazyImage) -> Option<Vec<PendingOp>> {
None
}
fn requires_data(&self) -> bool {
true }
}
struct ClampTransform {
min: f32,
max: f32,
}
impl Transform for ClampTransform {
fn apply(&self, image: &NiftiImage) -> Result<NiftiImage> {
crate::transforms::clamp(image, self.min as f64, self.max as f64)
}
}
impl LazyTransform for ClampTransform {
fn to_pending_op(&self, _image: &LazyImage) -> Option<Vec<PendingOp>> {
Some(vec![PendingOp::Clamp {
min: self.min,
max: self.max,
}])
}
fn requires_data(&self) -> bool {
false
}
}
struct ResampleSpacingTransform {
spacing: [f32; 3],
}
impl Transform for ResampleSpacingTransform {
fn apply(&self, image: &NiftiImage) -> Result<NiftiImage> {
crate::transforms::resample_to_spacing(
image,
self.spacing,
crate::transforms::Interpolation::Trilinear,
)
}
}
impl LazyTransform for ResampleSpacingTransform {
fn to_pending_op(&self, image: &LazyImage) -> Option<Vec<PendingOp>> {
let current_spacing = image.image.as_ref()?.spacing();
let scale_x = self.spacing[0] / current_spacing[0];
let scale_y = self.spacing[1] / current_spacing[1];
let scale_z = self.spacing[2] / current_spacing[2];
let shape = image.image.as_ref()?.shape();
let new_shape = [
((shape[0] as f32) / scale_x).round() as usize,
((shape[1] as f32) / scale_y).round() as usize,
((shape[2] as f32) / scale_z).round() as usize,
];
let matrix = [
[scale_x, 0.0, 0.0, 0.0],
[0.0, scale_y, 0.0, 0.0],
[0.0, 0.0, scale_z, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
Some(vec![PendingOp::Affine {
matrix,
output_shape: Some(new_shape),
interpolation: super::lazy::Interpolation::Trilinear,
}])
}
fn requires_data(&self) -> bool {
false }
}
struct ResampleShapeTransform {
shape: [usize; 3],
}
impl Transform for ResampleShapeTransform {
fn apply(&self, image: &NiftiImage) -> Result<NiftiImage> {
crate::transforms::resample_to_shape(
image,
self.shape,
crate::transforms::Interpolation::Trilinear,
)
}
}
impl LazyTransform for ResampleShapeTransform {
fn to_pending_op(&self, image: &LazyImage) -> Option<Vec<PendingOp>> {
let current_shape = image.image.as_ref()?.shape();
let scale_x = current_shape[0] as f32 / self.shape[0] as f32;
let scale_y = current_shape[1] as f32 / self.shape[1] as f32;
let scale_z = current_shape[2] as f32 / self.shape[2] as f32;
let matrix = [
[scale_x, 0.0, 0.0, 0.0],
[0.0, scale_y, 0.0, 0.0],
[0.0, 0.0, scale_z, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
Some(vec![PendingOp::Affine {
matrix,
output_shape: Some(self.shape),
interpolation: super::lazy::Interpolation::Trilinear,
}])
}
fn requires_data(&self) -> bool {
false }
}
struct FlipTransform {
axes: u8,
}
impl Transform for FlipTransform {
fn apply(&self, image: &NiftiImage) -> Result<NiftiImage> {
let axes_vec: Vec<usize> = (0..3).filter(|&i| (self.axes >> i) & 1 == 1).collect();
crate::transforms::flip(image, &axes_vec)
}
}
impl LazyTransform for FlipTransform {
fn to_pending_op(&self, _image: &LazyImage) -> Option<Vec<PendingOp>> {
Some(vec![PendingOp::Flip { axes: self.axes }])
}
fn requires_data(&self) -> bool {
false
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::nifti::NiftiImage;
use ndarray::ArrayD;
fn create_test_image(data: Vec<f32>, shape: [usize; 3]) -> NiftiImage {
use ndarray::ShapeBuilder;
let c_order = ArrayD::from_shape_vec(shape.to_vec(), data).unwrap();
let mut f_order = ArrayD::zeros(ndarray::IxDyn(&shape).f());
f_order.assign(&c_order);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
NiftiImage::from_array(f_order, affine)
}
#[test]
fn test_compose_basic() {
let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
let img = create_test_image(data, [4, 4, 4]);
let pipeline = Compose::new()
.push(ZNormalizeTransform)
.push(ClampTransform {
min: -2.0,
max: 2.0,
});
let result = pipeline.apply(&img).expect("pipeline should succeed");
assert_eq!(result.shape(), &[4, 4, 4]);
}
#[test]
fn test_compose_eager_vs_lazy() {
let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
let img = create_test_image(data, [4, 4, 4]);
let eager_pipeline = Compose::eager().push(ClampTransform {
min: 0.0,
max: 50.0,
});
let eager_result = eager_pipeline
.apply(&img)
.expect("eager pipeline should succeed");
let lazy_pipeline = Compose::new().push(ClampTransform {
min: 0.0,
max: 50.0,
});
let lazy_result = lazy_pipeline
.apply(&img)
.expect("lazy pipeline should succeed");
assert_eq!(eager_result.shape(), lazy_result.shape());
}
#[test]
fn test_transform_pipeline_fluent() {
let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
let img = create_test_image(data, [4, 4, 4]);
let pipeline = TransformPipeline::new()
.z_normalize()
.clamp(-1.0, 1.0)
.flip(&[0]);
let result = pipeline.apply(&img).expect("pipeline should succeed");
assert_eq!(result.shape(), &[4, 4, 4]);
}
#[test]
fn test_compose_resample_shape() {
let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
let img = create_test_image(data, [4, 4, 4]);
let pipeline = TransformPipeline::new().resample_to_shape([8, 8, 8]);
let result = pipeline.apply(&img).expect("pipeline should succeed");
assert_eq!(result.shape(), &[8, 8, 8]);
}
#[test]
fn test_compose_chain_resamples() {
let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
let img = create_test_image(data, [4, 4, 4]);
let pipeline = TransformPipeline::new()
.resample_to_shape([8, 8, 8])
.resample_to_shape([16, 16, 16]);
let result = pipeline.apply(&img).expect("pipeline should succeed");
assert_eq!(result.shape(), &[16, 16, 16]);
}
#[test]
fn test_compose_intensity_then_resample() {
let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
let img = create_test_image(data, [4, 4, 4]);
let pipeline = TransformPipeline::new()
.z_normalize()
.clamp(-1.0, 1.0)
.resample_to_shape([8, 8, 8]);
let result = pipeline.apply(&img).expect("pipeline should succeed");
assert_eq!(result.shape(), &[8, 8, 8]);
let data = result.to_f32().unwrap();
for &v in data.iter() {
assert!(v >= -1.0 && v <= 1.0, "Value {} outside clamp range", v);
}
}
#[test]
fn test_compose_is_empty() {
let empty = Compose::new();
assert!(empty.is_empty());
assert_eq!(empty.len(), 0);
let with_transform = Compose::new().push(ZNormalizeTransform);
assert!(!with_transform.is_empty());
assert_eq!(with_transform.len(), 1);
}
}