use super::header::NiftiHeader;
use super::image::NiftiImage;
use crate::error::{Error, Result};
use flate2::bufread::GzDecoder;
use libdeflater::Decompressor;
use memmap2::Mmap;
use rand::Rng;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Cursor, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use crate::transforms::{self, Interpolation};
#[must_use = "this function returns a loaded image that should be used"]
pub fn load<P: AsRef<Path>>(path: P) -> Result<NiftiImage> {
let path = path.as_ref();
let is_gzipped = path.extension().is_some_and(|e| e == "gz");
if is_gzipped {
load_gzipped(path)
} else {
load_uncompressed(path)
}
}
#[allow(unsafe_code)]
fn load_uncompressed(path: &Path) -> Result<NiftiImage> {
let file = File::open(path)?;
let mmap = unsafe { Mmap::map(&file)? };
let header = NiftiHeader::from_bytes(&mmap)?;
let offset = header.vox_offset as usize;
let data_size = header.data_size();
if mmap.len() < offset + data_size {
return Err(Error::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"file truncated",
)));
}
let arc = Arc::new(mmap);
Ok(NiftiImage::from_shared_mmap(header, arc, offset, data_size))
}
fn load_gzipped(path: &Path) -> Result<NiftiImage> {
let compressed = std::fs::read(path)?;
let mut header_buf = [0u8; NiftiHeader::SIZE];
{
let cursor = Cursor::new(&compressed);
let mut decoder = GzDecoder::new(cursor);
decoder.read_exact(&mut header_buf)?;
}
let header = NiftiHeader::from_bytes(&header_buf)?;
let offset = header.vox_offset as usize;
let data_size = header.data_size();
let total_size = offset + data_size;
let mut output = vec![0u8; total_size];
let mut decompressor = Decompressor::new();
let written = decompressor
.gzip_decompress(&compressed, &mut output)
.map_err(|e| Error::Decompression(format!("{}", e)))?;
if written != total_size {
return Err(Error::Decompression(format!(
"decompressed size {} did not match expected {}",
written, total_size
)));
}
let bytes = Arc::new(output);
Ok(NiftiImage::from_shared_bytes(
header, bytes, offset, data_size,
))
}
pub fn save<P: AsRef<Path>>(image: &NiftiImage, path: P) -> Result<()> {
image.header().validate()?;
let path = path.as_ref();
let is_gzipped = path.extension().is_some_and(|e| e == "gz");
if is_gzipped {
save_gzipped(image, path)
} else {
save_uncompressed(image, path)
}
}
fn save_uncompressed(image: &NiftiImage, path: &Path) -> Result<()> {
let file = File::create(path)?;
let mut writer = BufWriter::with_capacity(1024 * 1024, file);
let header_bytes = image.header().to_bytes();
writer.write_all(&header_bytes)?;
let padding = image.header().vox_offset as usize - NiftiHeader::SIZE;
if padding > 0 {
writer.write_all(&vec![0u8; padding])?;
}
let data = image.data_to_bytes()?;
writer.write_all(&data)?;
writer.flush()?;
Ok(())
}
fn save_gzipped(image: &NiftiImage, path: &Path) -> Result<()> {
let header_bytes = image.header().to_bytes();
let padding = image.header().vox_offset as usize - NiftiHeader::SIZE;
let data = image.data_to_bytes()?;
let total_size = header_bytes.len() + padding + data.len();
let mut uncompressed = Vec::with_capacity(total_size);
uncompressed.extend_from_slice(&header_bytes);
uncompressed.resize(uncompressed.len() + padding, 0u8);
uncompressed.extend_from_slice(&data);
let mut compressor = libdeflater::Compressor::new(libdeflater::CompressionLvl::fastest());
let max_compressed_size = compressor.gzip_compress_bound(uncompressed.len());
let mut compressed = vec![0u8; max_compressed_size];
let actual_size = compressor
.gzip_compress(&uncompressed, &mut compressed)
.map_err(|e| Error::Io(std::io::Error::other(format!("compression failed: {e:?}"))))?;
compressed.truncate(actual_size);
let mut file = File::create(path)?;
file.write_all(&compressed)?;
Ok(())
}
#[allow(unsafe_code)]
pub fn load_header<P: AsRef<Path>>(path: P) -> Result<NiftiHeader> {
let path = path.as_ref();
let is_gzipped = path.extension().is_some_and(|e| e == "gz");
if is_gzipped {
let file = File::open(path)?;
let buf_reader = BufReader::new(file);
let mut decoder = GzDecoder::new(buf_reader);
let mut header_buf = vec![0u8; NiftiHeader::SIZE];
decoder.read_exact(&mut header_buf)?;
NiftiHeader::from_bytes(&header_buf)
} else {
let file = File::open(path)?;
let mmap = unsafe { Mmap::map(&file)? };
NiftiHeader::from_bytes(&mmap)
}
}
pub struct LoadCroppedConfig {
pub output_shape: [usize; 3],
pub target_spacing: Option<[f32; 3]>,
pub target_orientation: Option<crate::transforms::Orientation>,
pub output_offset: Option<[usize; 3]>,
}
impl Default for LoadCroppedConfig {
fn default() -> Self {
Self {
output_shape: [64, 64, 64],
target_spacing: None,
target_orientation: None,
output_offset: None,
}
}
}
#[must_use = "this function returns a loaded image that should be used"]
#[allow(unsafe_code)]
pub fn load_cropped_config<P: AsRef<Path>>(
path: P,
config: LoadCroppedConfig,
) -> Result<NiftiImage> {
let path = path.as_ref();
ensure_uncompressed(path)?;
let file = File::open(path)?;
let mmap = unsafe { Mmap::map(&file)? };
let header = NiftiHeader::from_bytes(&mmap)?;
let data_offset = header.vox_offset as usize;
let shape = header.shape();
let crop_offset = config
.output_offset
.unwrap_or(compute_center_offset(&shape, &config.output_shape)?);
let cropped = copy_cropped_region(
&header,
&mmap,
data_offset,
crop_offset,
config.output_shape,
)?;
let mut output = cropped;
if let Some(target_orient) = config.target_orientation {
output = transforms::reorient(&output, target_orient)?;
}
if let Some(target_spacing) = config.target_spacing {
output =
transforms::resample_to_spacing(&output, target_spacing, Interpolation::Trilinear)?;
}
Ok(output)
}
#[must_use = "this function returns a loaded image that should be used"]
#[allow(unsafe_code)]
pub fn load_cropped<P: AsRef<Path>>(
path: P,
crop_offset: [usize; 3],
crop_shape: [usize; 3],
) -> Result<NiftiImage> {
let path = path.as_ref();
ensure_uncompressed(path)?;
let file = File::open(path)?;
let mmap = unsafe { Mmap::map(&file)? };
let header = NiftiHeader::from_bytes(&mmap)?;
let data_offset = header.vox_offset as usize;
copy_cropped_region(&header, &mmap, data_offset, crop_offset, crop_shape)
}
fn ensure_uncompressed(path: &Path) -> Result<()> {
if path.extension().is_some_and(|e| e == "gz") {
return Err(Error::InvalidDimensions(
"load_cropped only supports uncompressed .nii files".to_string(),
));
}
Ok(())
}
fn compute_center_offset(full_shape: &[u16], crop_shape: &[usize; 3]) -> Result<[usize; 3]> {
if full_shape.len() < 3 {
return Err(Error::InvalidDimensions(
"expected at least 3 spatial dimensions".to_string(),
));
}
Ok([
(full_shape[0] as usize).saturating_sub(crop_shape[0]) / 2,
(full_shape[1] as usize).saturating_sub(crop_shape[1]) / 2,
(full_shape[2] as usize).saturating_sub(crop_shape[2]) / 2,
])
}
fn copy_cropped_region(
header: &NiftiHeader,
mmap: &Mmap,
data_offset: usize,
crop_offset: [usize; 3],
crop_shape: [usize; 3],
) -> Result<NiftiImage> {
let full_shape = header.shape();
if full_shape.len() < 3 {
return Err(Error::InvalidDimensions(
"expected at least 3 spatial dimensions".to_string(),
));
}
for (i, &dim) in crop_shape.iter().enumerate() {
if dim == 0 {
return Err(Error::InvalidDimensions(format!(
"Crop shape dimension {} cannot be zero",
i
)));
}
}
for i in 0..3 {
if crop_offset[i] + crop_shape[i] > full_shape[i] as usize {
return Err(Error::InvalidDimensions(format!(
"Crop region exceeds dimension {}: {} + {} > {}",
i, crop_offset[i], crop_shape[i], full_shape[i]
)));
}
}
let elem_size = header.datatype.size();
let dim0 = full_shape[0] as usize; let dim1 = full_shape.get(1).copied().unwrap_or(1) as usize;
let total_bytes = crop_shape[0]
.checked_mul(crop_shape[1])
.and_then(|v| v.checked_mul(crop_shape[2]))
.and_then(|v| v.checked_mul(elem_size))
.ok_or_else(|| {
Error::InvalidDimensions(format!(
"Crop region too large: {:?} x {} bytes would overflow",
crop_shape, elem_size
))
})?;
let mut buffer = vec![0u8; total_bytes];
let expected_data_size = (full_shape[0] as usize)
.checked_mul(full_shape[1] as usize)
.and_then(|v| v.checked_mul(full_shape[2] as usize))
.and_then(|v| v.checked_mul(elem_size))
.ok_or_else(|| {
Error::InvalidDimensions(format!(
"Volume too large: {:?} x {} bytes would overflow",
&full_shape[..3],
elem_size
))
})?;
let total_required = data_offset
.checked_add(expected_data_size)
.ok_or_else(|| Error::InvalidDimensions("Data offset + size would overflow".to_string()))?;
if mmap.len() < total_required {
return Err(Error::InvalidDimensions(format!(
"File too small: need {} bytes for data but mmap is {} bytes (offset {})",
expected_data_size,
mmap.len(),
data_offset
)));
}
let row_bytes = crop_shape[0] * elem_size;
let mmap_slice = mmap.as_ref();
let slice_bytes = crop_shape[0] * crop_shape[1] * elem_size;
let use_parallel = slice_bytes > 65536 && crop_shape[2] >= 4;
if use_parallel {
use rayon::prelude::*;
let slices_per_z = crop_shape[1];
let bytes_per_z = slices_per_z * row_bytes;
buffer
.par_chunks_mut(bytes_per_z)
.enumerate()
.for_each(|(z, z_buffer)| {
let src_z = crop_offset[2] + z;
let z_offset = src_z * dim0 * dim1;
for y in 0..crop_shape[1] {
let src_y = crop_offset[1] + y;
let src_index = crop_offset[0] + src_y * dim0 + z_offset;
let src_byte = data_offset + src_index * elem_size;
let dst_start = y * row_bytes;
z_buffer[dst_start..dst_start + row_bytes]
.copy_from_slice(&mmap_slice[src_byte..src_byte + row_bytes]);
}
});
} else {
let mut dst_cursor = 0;
for z in 0..crop_shape[2] {
let src_z = crop_offset[2] + z;
let z_offset = src_z * dim0 * dim1;
for y in 0..crop_shape[1] {
let src_y = crop_offset[1] + y;
let src_index = crop_offset[0] + src_y * dim0 + z_offset;
let src_byte = data_offset + src_index * elem_size;
let src_range = src_byte..src_byte + row_bytes;
let dst_range = dst_cursor..dst_cursor + row_bytes;
buffer[dst_range].copy_from_slice(&mmap_slice[src_range]);
dst_cursor += row_bytes;
}
}
}
let mut new_header = header.clone();
new_header.ndim = 3;
new_header.dim = [1u16; 7];
for (i, &s) in crop_shape.iter().enumerate() {
new_header.dim[i] = s as u16;
}
new_header.vox_offset = NiftiHeader::default().vox_offset;
let mut affine = new_header.affine();
for row in affine.iter_mut().take(3) {
row[3] += row[0] * crop_offset[0] as f32
+ row[1] * crop_offset[1] as f32
+ row[2] * crop_offset[2] as f32;
}
new_header.set_affine(affine);
let data_len = buffer.len();
Ok(NiftiImage::from_shared_bytes(
new_header,
Arc::new(buffer),
0,
data_len,
))
}
pub struct CropLoaderConfig {
pub patch_size: [usize; 3],
pub patches_per_volume: usize,
pub patch_overlap: [usize; 3],
pub randomize: bool,
}
impl Default for CropLoaderConfig {
fn default() -> Self {
Self {
patch_size: [64, 64, 64],
patches_per_volume: 4,
patch_overlap: [0, 0, 0],
randomize: false,
}
}
}
pub struct CropLoader {
volumes: Vec<PathBuf>,
current_volume: usize,
patches_extracted: usize,
config: CropLoaderConfig,
}
impl CropLoader {
pub fn new<P: AsRef<Path>>(volumes: Vec<P>, config: CropLoaderConfig) -> Self {
Self {
volumes: volumes.iter().map(|p| p.as_ref().to_path_buf()).collect(),
current_volume: 0,
patches_extracted: 0,
config,
}
}
#[allow(unsafe_code)]
pub fn next_patch(&mut self) -> Result<NiftiImage> {
if self.current_volume >= self.volumes.len() {
return Err(Error::Exhausted("all volumes processed".to_string()));
}
if self.config.patches_per_volume == 0 {
return Err(Error::InvalidDimensions(
"patches_per_volume must be positive".to_string(),
));
}
let path = &self.volumes[self.current_volume];
let file = File::open(path)?;
let mmap = unsafe { Mmap::map(&file)? };
let header = NiftiHeader::from_bytes(&mmap)?;
let volume_shape = header.shape();
if volume_shape.len() < 3 {
return Err(Error::InvalidDimensions(
"expected at least 3 spatial dimensions".to_string(),
));
}
for (i, (&dim, &patch_dim)) in volume_shape
.iter()
.zip(self.config.patch_size.iter())
.take(3)
.enumerate()
{
if patch_dim == 0 {
return Err(Error::InvalidDimensions(format!(
"patch_size[{}] must be positive",
i
)));
}
if patch_dim > dim as usize {
return Err(Error::InvalidDimensions(format!(
"patch_size[{}]={} cannot exceed image dimension[{}]={}",
i, patch_dim, i, dim
)));
}
}
for i in 0..3 {
if self.config.patch_overlap[i] >= self.config.patch_size[i] {
return Err(Error::InvalidDimensions(
"patch_overlap must be smaller than patch_size in all dimensions".to_string(),
));
}
}
let patch_positions = if self.config.randomize {
self.random_patch_positions(&volume_shape)
} else {
self.grid_patch_positions(&volume_shape)?
};
if self.patches_extracted >= patch_positions.len() {
self.current_volume += 1;
self.patches_extracted = 0;
return self.next_patch(); }
let patch_offset = patch_positions[self.patches_extracted];
self.patches_extracted += 1;
load_cropped(path, patch_offset, self.config.patch_size)
}
fn grid_patch_positions(&self, shape: &[u16]) -> Result<Vec<[usize; 3]>> {
let [pd, ph, pw] = self.config.patch_size;
let [od, oh, ow] = self.config.patch_overlap;
let step_d = pd.saturating_sub(od);
let step_h = ph.saturating_sub(oh);
let step_w = pw.saturating_sub(ow);
if step_d == 0 || step_h == 0 || step_w == 0 {
return Err(Error::InvalidDimensions(
"patch_size must be larger than patch_overlap in all dimensions".to_string(),
));
}
let mut positions = Vec::new();
let max_d = (shape[0] as usize).saturating_sub(pd);
let max_h = (*shape.get(1).unwrap_or(&1) as usize).saturating_sub(ph);
let max_w = (*shape.get(2).unwrap_or(&1) as usize).saturating_sub(pw);
for d in (0..=max_d).step_by(step_d) {
for h in (0..=max_h).step_by(step_h) {
for w in (0..=max_w).step_by(step_w) {
positions.push([d, h, w]);
}
}
}
Ok(positions)
}
fn random_patch_positions(&self, shape: &[u16]) -> Vec<[usize; 3]> {
use rand::thread_rng;
let max_d =
(*shape.first().unwrap_or(&1) as usize).saturating_sub(self.config.patch_size[0]);
let max_h =
(*shape.get(1).unwrap_or(&1) as usize).saturating_sub(self.config.patch_size[1]);
let max_w =
(*shape.get(2).unwrap_or(&1) as usize).saturating_sub(self.config.patch_size[2]);
let mut rng = thread_rng();
let mut positions = Vec::new();
for _ in 0..self.config.patches_per_volume {
positions.push([
rng.gen_range(0..=max_d),
rng.gen_range(0..=max_h),
rng.gen_range(0..=max_w),
]);
}
positions
}
}
pub struct BatchLoader {
loader: CropLoader,
batch_size: usize,
}
impl BatchLoader {
pub fn new<P: AsRef<Path>>(volumes: Vec<P>, batch_size: usize) -> Self {
let config = CropLoaderConfig::default();
Self {
loader: CropLoader::new(volumes, config),
batch_size,
}
}
pub fn next_batch(&mut self) -> Result<Vec<NiftiImage>> {
if self.batch_size == 0 {
return Err(Error::InvalidDimensions(
"batch_size must be positive".to_string(),
));
}
let mut batch = Vec::with_capacity(self.batch_size);
for _ in 0..self.batch_size {
match self.loader.next_patch() {
Ok(patch) => batch.push(patch),
Err(Error::Exhausted(_)) => break,
Err(e) => return Err(e),
}
}
if batch.is_empty() {
Err(Error::Exhausted("no more patches available".to_string()))
} else {
Ok(batch)
}
}
}
pub struct TrainingDataLoader {
volumes: Vec<PathBuf>,
config: CropLoaderConfig,
cache: HashMap<usize, Vec<NiftiImage>>,
max_cache_size: usize,
current_volume: usize,
current_patch: usize,
prefetch_queue: Vec<(usize, Vec<[usize; 3]>)>,
patches_processed: usize,
}
impl TrainingDataLoader {
pub fn new<P: AsRef<Path>>(
volumes: Vec<P>,
config: CropLoaderConfig,
cache_size: usize,
) -> Result<Self> {
let volumes: Vec<PathBuf> = volumes
.into_iter()
.map(|p| p.as_ref().to_path_buf())
.collect();
if volumes.is_empty() {
return Err(Error::InvalidDimensions("No volumes provided".to_string()));
}
for i in 0..3 {
if config.patch_size[i] == 0 {
return Err(Error::InvalidDimensions(format!(
"patch_size[{}] must be positive",
i
)));
}
}
if config.patches_per_volume == 0 {
return Err(Error::InvalidDimensions(
"patches_per_volume must be positive".to_string(),
));
}
for i in 0..3 {
if config.patch_overlap[i] >= config.patch_size[i] {
return Err(Error::InvalidDimensions(
"patch_overlap must be smaller than patch_size in all dimensions".to_string(),
));
}
}
let mut loader = Self {
cache: HashMap::new(),
max_cache_size: cache_size,
current_volume: 0,
current_patch: 0,
prefetch_queue: Vec::new(),
patches_processed: 0,
volumes,
config,
};
loader.initialize_prefetch()?;
Ok(loader)
}
fn initialize_prefetch(&mut self) -> Result<()> {
let prefetch_count = (self.max_cache_size / self.config.patches_per_volume).min(3);
for i in 0..prefetch_count.min(self.volumes.len()) {
let patch_positions = self.compute_patch_positions(&self.volumes[i])?;
self.prefetch_queue.push((i, patch_positions));
}
Ok(())
}
pub fn next_patch(&mut self) -> Result<NiftiImage> {
if !self.cache.contains_key(&self.current_volume) {
self.load_volume_patches(self.current_volume)?;
}
let patches = self.cache.get_mut(&self.current_volume).ok_or_else(|| {
Error::InvalidDimensions("cache invariant violated: volume should be loaded".into())
})?;
if self.current_patch >= patches.len() {
self.current_volume += 1;
self.current_patch = 0;
if self.current_volume >= self.volumes.len() {
return Err(Error::Exhausted("all patches processed".to_string()));
}
return self.next_patch();
}
let patch = patches.swap_remove(self.current_patch); self.patches_processed += 1;
if self.cache.len() < 3 && self.current_volume + 2 < self.volumes.len() {
self.trigger_prefetch(self.current_volume + 2)?;
}
Ok(patch)
}
fn load_volume_patches(&mut self, volume_idx: usize) -> Result<()> {
let volume_path = &self.volumes[volume_idx];
let patch_positions = self.compute_patch_positions(volume_path)?;
let mut patches = Vec::with_capacity(patch_positions.len());
for position in patch_positions {
let patch = load_cropped(volume_path, position, self.config.patch_size)?;
patches.push(patch);
}
self.cache.insert(volume_idx, patches);
Ok(())
}
fn compute_patch_positions(&self, volume_path: &Path) -> Result<Vec<[usize; 3]>> {
let header = load_header(volume_path)?;
let shape = header.shape();
if shape.len() < 3 {
return Err(Error::InvalidDimensions(
"expected at least 3 spatial dimensions".to_string(),
));
}
for (i, (&dim, &patch_dim)) in shape
.iter()
.zip(self.config.patch_size.iter())
.take(3)
.enumerate()
{
if patch_dim > dim as usize {
return Err(Error::InvalidDimensions(format!(
"patch_size[{}]={} cannot exceed image dimension[{}]={}",
i, patch_dim, i, dim
)));
}
}
if self.config.randomize {
Ok(self.random_patch_positions(shape))
} else {
self.grid_patch_positions(shape)
}
}
fn grid_patch_positions(&self, shape: &[u16]) -> Result<Vec<[usize; 3]>> {
let [pd, ph, pw] = self.config.patch_size;
let [od, oh, ow] = self.config.patch_overlap;
let max_d = (shape[0] as usize).saturating_sub(pd);
let max_h = (*shape.get(1).unwrap_or(&1) as usize).saturating_sub(ph);
let max_w = (*shape.get(2).unwrap_or(&1) as usize).saturating_sub(pw);
let step_d = pd.saturating_sub(od);
let step_h = ph.saturating_sub(oh);
let step_w = pw.saturating_sub(ow);
if step_d == 0 || step_h == 0 || step_w == 0 {
return Err(Error::InvalidDimensions(
"patch_size must be larger than patch_overlap in all dimensions".to_string(),
));
}
let mut positions = Vec::new();
for d in (0..=max_d).step_by(step_d) {
for h in (0..=max_h).step_by(step_h) {
for w in (0..=max_w).step_by(step_w) {
positions.push([d, h, w]);
}
}
}
if positions.len() > self.config.patches_per_volume {
positions.truncate(self.config.patches_per_volume);
} else if positions.len() < self.config.patches_per_volume {
let mut rng = rand::thread_rng();
while positions.len() < self.config.patches_per_volume {
let d = rng.gen_range(0..=max_d);
let h = rng.gen_range(0..=max_h);
let w = rng.gen_range(0..=max_w);
positions.push([d, h, w]);
}
}
Ok(positions)
}
fn random_patch_positions(&self, shape: &[u16]) -> Vec<[usize; 3]> {
let [pd, ph, pw] = self.config.patch_size;
let max_d = (shape[0] as usize).saturating_sub(pd);
let max_h = (*shape.get(1).unwrap_or(&1) as usize).saturating_sub(ph);
let max_w = (*shape.get(2).unwrap_or(&1) as usize).saturating_sub(pw);
let mut rng = rand::thread_rng();
let mut positions = Vec::with_capacity(self.config.patches_per_volume);
for _ in 0..self.config.patches_per_volume {
positions.push([
rng.gen_range(0..=max_d),
rng.gen_range(0..=max_h),
rng.gen_range(0..=max_w),
]);
}
positions
}
fn trigger_prefetch(&mut self, volume_idx: usize) -> Result<()> {
if volume_idx >= self.volumes.len() {
return Ok(());
}
let patch_positions = self.compute_patch_positions(&self.volumes[volume_idx])?;
self.prefetch_queue.push((volume_idx, patch_positions));
Ok(())
}
pub fn stats(&self) -> LoaderStats {
LoaderStats {
total_volumes: self.volumes.len(),
current_volume: self.current_volume,
cached_volumes: self.cache.len(),
patches_processed: self.patches_processed,
cache_size: self.cache.values().map(|patches| patches.len()).sum(),
max_cache_size: self.max_cache_size,
}
}
pub fn reset(&mut self) -> Result<()> {
self.cache.clear();
self.current_volume = 0;
self.current_patch = 0;
self.patches_processed = 0;
self.prefetch_queue.clear();
self.initialize_prefetch()?;
Ok(())
}
pub fn volumes_len(&self) -> usize {
self.volumes.len()
}
pub fn patches_per_volume(&self) -> usize {
self.config.patches_per_volume
}
}
#[derive(Debug, Clone)]
pub struct LoaderStats {
pub total_volumes: usize,
pub current_volume: usize,
pub cached_volumes: usize,
pub patches_processed: usize,
pub cache_size: usize,
pub max_cache_size: usize,
}
impl std::fmt::Display for LoaderStats {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "Training Data Loader Statistics:")?;
writeln!(f, " Total volumes: {}", self.total_volumes)?;
writeln!(
f,
" Current volume: {}/{}",
self.current_volume + 1,
self.total_volumes
)?;
writeln!(f, " Cached volumes: {}", self.cached_volumes)?;
writeln!(f, " Patches processed: {}", self.patches_processed)?;
writeln!(
f,
" Cache utilization: {}/{} patches",
self.cache_size, self.max_cache_size
)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::s;
use ndarray::ArrayD;
use ndarray::ShapeBuilder;
use tempfile::tempdir;
fn create_f_order_array(data: Vec<f32>, shape: Vec<usize>) -> ArrayD<f32> {
let c_order = ArrayD::from_shape_vec(shape.clone(), data).unwrap();
let mut f_order = ArrayD::zeros(ndarray::IxDyn(&shape).f());
f_order.assign(&c_order);
f_order
}
#[test]
fn test_roundtrip_uncompressed() {
let dir = tempdir().unwrap();
let path = dir.path().join("test.nii");
let data = create_f_order_array((0..1000).map(|i| i as f32).collect(), vec![10, 10, 10]);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data.clone(), affine);
save(&img, &path).unwrap();
let loaded = load(&path).unwrap();
assert_eq!(loaded.shape(), &[10, 10, 10]);
let loaded_data = loaded.to_f32().unwrap();
assert_eq!(
loaded_data.as_slice_memory_order().unwrap(),
data.as_slice_memory_order().unwrap()
);
}
#[test]
fn test_roundtrip_gzipped() {
let dir = tempdir().unwrap();
let path = dir.path().join("test.nii.gz");
let data = create_f_order_array((0..1000).map(|i| i as f32).collect(), vec![10, 10, 10]);
let affine = [
[2.0, 0.0, 0.0, -10.0],
[0.0, 2.0, 0.0, -10.0],
[0.0, 0.0, 2.0, -10.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data.clone(), affine);
save(&img, &path).unwrap();
let loaded = load(&path).unwrap();
assert_eq!(loaded.shape(), &[10, 10, 10]);
assert_eq!(loaded.affine(), affine);
let loaded_data = loaded.to_f32().unwrap();
assert_eq!(
loaded_data.as_slice_memory_order().unwrap(),
data.as_slice_memory_order().unwrap()
);
}
#[test]
fn test_load_cropped_byte_exact() {
let dir = tempdir().unwrap();
let path = dir.path().join("test.nii");
let data = create_f_order_array((0..131072).map(|i| i as f32).collect(), vec![64, 64, 32]);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data.clone(), affine);
save(&img, &path).unwrap();
let crop_offset = [16, 16, 8];
let crop_shape = [32, 32, 16];
let cropped = load_cropped(&path, crop_offset, crop_shape).unwrap();
assert_eq!(cropped.shape(), &[32, 32, 16]);
let original_slice = data.slice(s![16..48, 16..48, 8..24]).to_owned();
let cropped_data = cropped.to_f32().unwrap();
for x in 0..32 {
for y in 0..32 {
for z in 0..16 {
let expected = original_slice[[x, y, z]];
let actual = cropped_data[[x, y, z]];
assert!(
(expected - actual).abs() < 1e-5,
"Mismatch at [{},{},{}]: expected {}, got {}",
x,
y,
z,
expected,
actual
);
}
}
}
}
#[test]
fn test_save_cropped_roundtrip() {
let dir = tempdir().unwrap();
let path = dir.path().join("test.nii");
let cropped_path = dir.path().join("cropped.nii");
let data = create_f_order_array((0..131072).map(|i| i as f32).collect(), vec![64, 64, 32]);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data.clone(), affine);
save(&img, &path).unwrap();
let crop_offset = [8, 8, 4];
let crop_shape = [16, 16, 8];
let cropped = load_cropped(&path, crop_offset, crop_shape).unwrap();
let cropped_data = cropped.to_f32().unwrap();
save(&cropped, &cropped_path).unwrap();
let loaded = load(&cropped_path).unwrap();
assert_eq!(loaded.shape(), &crop_shape);
let loaded_data = loaded.to_f32().unwrap();
assert_eq!(
loaded_data.as_slice_memory_order().unwrap(),
cropped_data.as_slice_memory_order().unwrap()
);
}
#[test]
fn test_load_cropped_config() {
let dir = tempdir().unwrap();
let path = dir.path().join("test.nii");
let data = create_f_order_array((0..16384).map(|i| i as f32).collect(), vec![32, 32, 16]);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let mut img = NiftiImage::from_array(data.clone(), affine);
img.header_mut().pixdim = [0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0];
save(&img, &path).unwrap();
let config = LoadCroppedConfig {
output_shape: [16, 16, 16],
target_spacing: None, target_orientation: None,
output_offset: None,
};
let loaded = load_cropped_config(&path, config).unwrap();
assert_eq!(loaded.shape(), &[16, 16, 16]);
}
#[test]
fn test_training_data_loader() {
let dir = tempdir().unwrap();
let paths = vec![dir.path().join("test1.nii"), dir.path().join("test2.nii")];
for (i, path) in paths.iter().enumerate() {
let size = 64 * 64 * 32;
let data = create_f_order_array(
((i * size)..((i + 1) * size)).map(|v| v as f32).collect(),
vec![64, 64, 32],
);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data, affine);
save(&img, path).unwrap();
}
let config = CropLoaderConfig {
patch_size: [32, 32, 16],
patches_per_volume: 2,
patch_overlap: [0, 0, 0],
randomize: false,
};
let mut loader = TrainingDataLoader::new(paths, config, 100).unwrap();
assert_eq!(loader.stats().total_volumes, 2);
let patch1 = loader.next_patch().unwrap();
assert_eq!(patch1.shape(), &[32, 32, 16]);
let patch2 = loader.next_patch().unwrap();
assert_eq!(patch2.shape(), &[32, 32, 16]);
let patch3 = loader.next_patch().unwrap();
assert_eq!(patch3.shape(), &[32, 32, 16]);
let patch4 = loader.next_patch().unwrap();
assert_eq!(patch4.shape(), &[32, 32, 16]);
let exhausted = loader.next_patch().unwrap_err();
assert!(matches!(exhausted, Error::Exhausted(_)));
let stats = loader.stats();
assert_eq!(stats.patches_processed, 4);
}
#[test]
fn test_training_data_loader_random() {
let dir = tempdir().unwrap();
let path = dir.path().join("test.nii");
let data = create_f_order_array((0..131072).map(|i| i as f32).collect(), vec![64, 64, 32]);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data, affine);
save(&img, &path).unwrap();
let config = CropLoaderConfig {
patch_size: [16, 16, 8],
patches_per_volume: 4,
patch_overlap: [0, 0, 0],
randomize: true,
};
let mut loader = TrainingDataLoader::new(vec![&path], config, 50).unwrap();
let patch1 = loader.next_patch().unwrap();
let patch2 = loader.next_patch().unwrap();
let _patch3 = loader.next_patch().unwrap();
let _patch4 = loader.next_patch().unwrap();
assert_eq!(patch1.shape(), &[16, 16, 8]);
assert_eq!(patch2.shape(), &[16, 16, 8]);
assert_eq!(_patch3.shape(), &[16, 16, 8]);
assert_eq!(_patch4.shape(), &[16, 16, 8]);
let data1 = patch1.to_f32().unwrap();
let data2 = patch2.to_f32().unwrap();
let _data3 = _patch3.to_f32().unwrap();
let _data4 = _patch4.to_f32().unwrap();
assert_ne!(data1, data2);
}
#[test]
fn test_crop_loader() {
let dir = tempdir().unwrap();
let paths = vec![dir.path().join("test1.nii"), dir.path().join("test2.nii")];
for (i, path) in paths.iter().enumerate() {
let size = 32 * 32 * 16;
let data = ArrayD::from_shape_vec(
vec![32, 32, 16],
((i * size)..((i + 1) * size)).map(|v| v as f32).collect(),
)
.unwrap();
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data, affine);
save(&img, path).unwrap();
}
let config = CropLoaderConfig {
patch_size: [16, 16, 8],
patches_per_volume: 2,
patch_overlap: [0, 0, 0],
randomize: false,
};
let mut loader = CropLoader::new(paths, config);
let patch1 = loader.next_patch().unwrap();
let patch2 = loader.next_patch().unwrap();
let patch3 = loader.next_patch().unwrap();
let patch4 = loader.next_patch().unwrap();
assert_eq!(patch1.shape(), &[16, 16, 8]);
assert_eq!(patch2.shape(), &[16, 16, 8]);
assert_eq!(patch3.shape(), &[16, 16, 8]);
assert_eq!(patch4.shape(), &[16, 16, 8]);
}
#[test]
fn training_data_loader_rejects_invalid_overlap() {
let dir = tempdir().unwrap();
let path = dir.path().join("test_invalid.nii");
let data = create_f_order_array((0..4096).map(|i| i as f32).collect(), vec![16, 16, 16]);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data, affine);
save(&img, &path).unwrap();
let cfg = CropLoaderConfig {
patch_size: [8, 8, 8],
patches_per_volume: 1,
patch_overlap: [8, 4, 4],
randomize: false,
};
let result = TrainingDataLoader::new(vec![path], cfg, 10);
assert!(result.is_err());
}
#[test]
fn training_data_loader_rejects_zero_patch_size() {
let dir = tempdir().unwrap();
let path = dir.path().join("test_zero_patch_size.nii");
let data = create_f_order_array((0..4096).map(|i| i as f32).collect(), vec![16, 16, 16]);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data, affine);
save(&img, &path).unwrap();
let cfg = CropLoaderConfig {
patch_size: [0, 8, 8],
patches_per_volume: 1,
patch_overlap: [0, 0, 0],
randomize: false,
};
let result = TrainingDataLoader::new(vec![&path], cfg, 10);
assert!(matches!(result, Err(Error::InvalidDimensions(_))));
}
#[test]
fn training_data_loader_rejects_zero_patches_per_volume() {
let dir = tempdir().unwrap();
let path = dir.path().join("test_zero_patches.nii");
let data = create_f_order_array((0..4096).map(|i| i as f32).collect(), vec![16, 16, 16]);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data, affine);
save(&img, &path).unwrap();
let cfg = CropLoaderConfig {
patch_size: [8, 8, 8],
patches_per_volume: 0,
patch_overlap: [0, 0, 0],
randomize: false,
};
let result = TrainingDataLoader::new(vec![&path], cfg, 10);
assert!(matches!(result, Err(Error::InvalidDimensions(_))));
}
#[test]
fn batch_loader_exhausts_cleanly() {
let dir = tempdir().unwrap();
let path = dir.path().join("test_batch.nii");
let data = create_f_order_array((0..262144).map(|i| i as f32).collect(), vec![64, 64, 64]);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data, affine);
save(&img, &path).unwrap();
let mut loader = BatchLoader::new(vec![&path], 2);
let batch = loader.next_batch().unwrap();
assert_eq!(batch.len(), 1);
let err = loader.next_batch().unwrap_err();
assert!(matches!(err, Error::Exhausted(_)));
}
#[test]
fn batch_loader_propagates_invalid_patch_error() {
let dir = tempdir().unwrap();
let path = dir.path().join("test_batch_invalid.nii");
let data = create_f_order_array((0..32768).map(|i| i as f32).collect(), vec![32, 32, 32]);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data, affine);
save(&img, &path).unwrap();
let mut loader = BatchLoader::new(vec![&path], 1);
let err = loader.next_batch().unwrap_err();
assert!(matches!(err, Error::InvalidDimensions(_)));
}
#[test]
fn test_memory_efficiency() {
let dir = tempdir().unwrap();
let path = dir.path().join("large_test.nii");
let data = create_f_order_array(
(0..(256 * 256 * 64)).map(|i| i as f32).collect(),
vec![256, 256, 64],
);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data.clone(), affine);
save(&img, &path).unwrap();
let crop_offset = [64, 64, 16];
let crop_shape = [64, 64, 32];
let cropped = load_cropped(&path, crop_offset, crop_shape).unwrap();
assert_eq!(cropped.shape(), crop_shape);
let original_slice = data.slice(s![64..128, 64..128, 16..48]).to_owned();
let cropped_data = cropped.to_f32().unwrap();
for x in 0..64 {
for y in 0..64 {
for z in 0..32 {
let expected = original_slice[[x, y, z]];
let actual = cropped_data[[x, y, z]];
assert!(
(expected - actual).abs() < 1e-5,
"Mismatch at [{},{},{}]: expected {}, got {}",
x,
y,
z,
expected,
actual
);
}
}
}
let full_size_bytes = 256 * 256 * 64 * 4; let crop_size_bytes = 64 * 64 * 32 * 4;
assert!(crop_size_bytes < full_size_bytes / 10); }
#[test]
fn test_patch_larger_than_volume_does_not_panic() {
let dir = tempdir().unwrap();
let path = dir.path().join("test_small_volume.nii");
let data = create_f_order_array((0..64).map(|i| i as f32).collect(), vec![4, 4, 4]);
let affine = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
];
let img = NiftiImage::from_array(data, affine);
save(&img, &path).unwrap();
let config = CropLoaderConfig {
patch_size: [8, 8, 8],
patches_per_volume: 2,
patch_overlap: [0, 0, 0],
randomize: false, };
let mut loader = CropLoader::new(vec![&path], config);
let result = loader.next_patch();
assert!(result.is_ok() || result.is_err());
let config_random = CropLoaderConfig {
patch_size: [8, 8, 8],
patches_per_volume: 2,
patch_overlap: [0, 0, 0],
randomize: true,
};
let mut loader_random = CropLoader::new(vec![&path], config_random);
let result_random = loader_random.next_patch();
assert!(result_random.is_ok() || result_random.is_err());
}
}