use std::io::Write;
use std::path::Path;
use anyhow::{Context, Result};
/// Grid geometry and processing options used when projecting surface data
/// into a NIfTI-1 volume.
#[derive(Debug, Clone)]
pub struct NiftiConfig {
    /// Grid extent in voxels as `(nx, ny, nz)`; the flat voxel buffer is laid
    /// out with x varying fastest (`i + j * nx + k * nx * ny`).
    pub dims: (usize, usize, usize),
    /// Isotropic voxel edge length in millimetres (one value for all three axes).
    pub voxel_size: f32,
    /// Millimetre offset added to world coordinates before dividing by
    /// `voxel_size`; voxel `(0, 0, 0)` therefore sits at world position `-origin_mm`.
    pub origin_mm: (f32, f32, f32),
    /// Gzip the output even when the target path does not end in `.gz`.
    pub compress: bool,
    /// Gaussian smoothing FWHM in millimetres; any value that is not `> 0`
    /// disables smoothing.
    pub smooth_fwhm_mm: f32,
}
impl Default for NiftiConfig {
fn default() -> Self {
Self {
dims: (96, 96, 96),
voxel_size: 2.0,
origin_mm: (90.0, 126.0, 72.0),
compress: true,
smooth_fwhm_mm: 6.0,
}
}
}
/// Builds the 348-byte little-endian NIfTI-1 header for a float32 3-D volume
/// described by `config`.
///
/// The qform and the sform both encode the same axis-aligned affine
/// `world = voxel_size * index - origin_mm`, so readers place voxels
/// identically whichever transform they prefer. `vox_offset` is 352 because
/// the writers emit a 4-byte extension flag between header and voxel data.
fn build_nifti1_header(config: &NiftiConfig) -> [u8; 348] {
    // Codes from nifti1.h.
    const DT_FLOAT32: i16 = 16;
    const NIFTI_XFORM_SCANNER_ANAT: i16 = 1;
    const NIFTI_XFORM_MNI_152: i16 = 4;
    const NIFTI_UNITS_MM: u8 = 2;

    /// Stores `v` little-endian at byte offset `off`.
    fn put_i16(hdr: &mut [u8], off: usize, v: i16) {
        hdr[off..off + 2].copy_from_slice(&v.to_le_bytes());
    }
    /// Stores `v` little-endian at byte offset `off`.
    fn put_f32(hdr: &mut [u8], off: usize, v: f32) {
        hdr[off..off + 4].copy_from_slice(&v.to_le_bytes());
    }

    let mut hdr = [0u8; 348];
    let (nx, ny, nz) = config.dims;
    let vs = config.voxel_size;
    let (ox, oy, oz) = config.origin_mm;

    // sizeof_hdr
    hdr[0..4].copy_from_slice(&348i32.to_le_bytes());

    // dim[0..8] at offset 40: rank 3, the grid extents, then singleton dims 4-7.
    // NOTE(review): `as i16` silently truncates extents above i16::MAX.
    let dims: [i16; 8] = [3, nx as i16, ny as i16, nz as i16, 1, 1, 1, 1];
    for (slot, &d) in dims.iter().enumerate() {
        put_i16(&mut hdr, 40 + 2 * slot, d);
    }

    put_i16(&mut hdr, 70, DT_FLOAT32); // datatype
    put_i16(&mut hdr, 72, 32); // bitpix

    // pixdim[0] is qfac (+1); pixdim[1..4] are the voxel edge lengths.
    for (slot, &p) in [1.0f32, vs, vs, vs].iter().enumerate() {
        put_f32(&mut hdr, 76 + 4 * slot, p);
    }

    put_f32(&mut hdr, 108, 352.0); // vox_offset: 348-byte header + 4-byte extension flag
    put_f32(&mut hdr, 112, 1.0); // scl_slope
    put_f32(&mut hdr, 116, 0.0); // scl_inter
    hdr[123] = NIFTI_UNITS_MM; // xyzt_units: all spatial values above are in mm

    put_i16(&mut hdr, 252, NIFTI_XFORM_SCANNER_ANAT); // qform_code
    put_i16(&mut hdr, 254, NIFTI_XFORM_MNI_152); // sform_code

    // qform: quatern_b/c/d stay 0 (identity rotation), so only the offsets are
    // needed to make it agree with the sform below. Previously they were left
    // at 0 while qform_code was non-zero, contradicting the sform.
    put_f32(&mut hdr, 268, -ox); // qoffset_x
    put_f32(&mut hdr, 272, -oy); // qoffset_y
    put_f32(&mut hdr, 276, -oz); // qoffset_z

    // sform rows srow_x / srow_y / srow_z at 280, 296, 312.
    let srows: [[f32; 4]; 3] = [
        [vs, 0.0, 0.0, -ox],
        [0.0, vs, 0.0, -oy],
        [0.0, 0.0, vs, -oz],
    ];
    for (row, values) in srows.iter().enumerate() {
        for (col, &v) in values.iter().enumerate() {
            put_f32(&mut hdr, 280 + 16 * row + 4 * col, v);
        }
    }

    // magic for a single-file (.nii) image
    hdr[344..348].copy_from_slice(b"n+1\0");
    hdr
}
/// Maps flat world-space (millimetre) vertex coordinates onto voxel indices of
/// the grid described by `config`.
///
/// `coords` is read as consecutive `(x, y, z)` triples; a trailing partial
/// triple is ignored. Each axis index is `round((coord + origin) / voxel_size)`,
/// the inverse of the `voxel_size * index - origin_mm` affine written to the
/// header. The result has one entry per vertex: `Some((i, j, k))` when the
/// vertex lands inside the grid, `None` when it falls outside or its position
/// is not finite (NaN or infinite).
pub fn mni_to_voxel(
    coords: &[f32],
    config: &NiftiConfig,
) -> Vec<Option<(usize, usize, usize)>> {
    let (nx, ny, nz) = config.dims;
    let vs = config.voxel_size;
    let (ox, oy, oz) = config.origin_mm;

    // Voxel index along one axis of `len` voxels, or `None` if out of range.
    let axis_index = |world: f32, origin: f32, len: usize| -> Option<usize> {
        let pos = ((world + origin) / vs).round();
        // `NaN as isize` evaluates to 0, which would silently drop a NaN vertex
        // into voxel 0; reject non-finite positions before casting.
        if !pos.is_finite() {
            return None;
        }
        let idx = pos as isize;
        if idx >= 0 && idx < len as isize {
            Some(idx as usize)
        } else {
            None
        }
    };

    coords
        .chunks_exact(3)
        .map(|p| {
            Some((
                axis_index(p[0], ox, nx)?,
                axis_index(p[1], oy, ny)?,
                axis_index(p[2], oz, nz)?,
            ))
        })
        .collect()
}
/// Scatters per-vertex values into a flat `nx × ny × nz` voxel grid (x fastest).
///
/// Vertices are paired with `vertex_values` in order; pairing stops at the
/// shorter of the two, and vertices that map outside the grid are dropped.
/// A voxel hit by several vertices receives their mean, and untouched voxels
/// stay 0. When `config.smooth_fwhm_mm > 0`, the result is passed through a
/// masked, normalised Gaussian whose sigma (in voxels) is derived from the FWHM.
pub fn surface_to_volume(
    vertex_values: &[f32],
    vertex_coords: &[f32],
    config: &NiftiConfig,
) -> Vec<f32> {
    let (nx, ny, nz) = config.dims;
    let plane = nx * ny;
    let total = plane * nz;
    let mut sums = vec![0.0f32; total];
    let mut hits = vec![0u32; total];

    let placements = mni_to_voxel(vertex_coords, config);
    for (&value, slot) in vertex_values.iter().zip(&placements) {
        let Some((i, j, k)) = *slot else { continue };
        let flat = i + j * nx + k * plane;
        sums[flat] += value;
        hits[flat] += 1;
    }

    // Turn multi-vertex sums into means; single hits are already their value.
    for (sum, &count) in sums.iter_mut().zip(&hits) {
        if count > 1 {
            *sum /= count as f32;
        }
    }

    if !(config.smooth_fwhm_mm > 0.0) {
        return sums;
    }
    // FWHM → sigma conversion (FWHM ≈ 2.355 σ), expressed in voxels.
    let sigma_voxels = config.smooth_fwhm_mm / (2.355 * config.voxel_size);
    gaussian_smooth_3d_masked(&sums, &hits, nx, ny, nz, sigma_voxels)
}
/// Normalised (masked) Gaussian smoothing of a sparsely populated volume.
///
/// Both `volume` and a 0/1 occupancy map (from `scatter_counts > 0`) are
/// convolved with a separable Gaussian of standard deviation `sigma` (in
/// voxels), truncated at `ceil(3 * sigma)` taps. Each output voxel is the ratio
/// of the two convolutions, i.e. a Gaussian-weighted average of occupied voxels.
/// Output is kept only within a spherical neighbourhood of that same radius
/// around an occupied voxel, and only where the occupancy weight exceeds 1e-8.
/// Everything else is 0.
fn gaussian_smooth_3d_masked(
    volume: &[f32],
    scatter_counts: &[u32],
    nx: usize,
    ny: usize,
    nz: usize,
    sigma: f32,
) -> Vec<f32> {
    let radius = (3.0 * sigma).ceil() as isize;
    let plane = nx * ny;
    let n_voxels = plane * nz;

    // 1.0 where at least one vertex landed, 0.0 elsewhere.
    let occupancy: Vec<f32> = scatter_counts
        .iter()
        .map(|&c| if c > 0 { 1.0 } else { 0.0 })
        .collect();

    // Unnormalised taps for offsets -radius..=radius; the ratio below cancels scale.
    let tap_count = (2 * radius + 1) as usize;
    let kernel: Vec<f32> = (0..tap_count)
        .map(|tap| {
            let t = (tap as f32 - radius as f32) / sigma;
            (-0.5 * t * t).exp()
        })
        .collect();

    let numerator = separable_convolve_3d(volume, nx, ny, nz, &kernel, radius);
    let denominator = separable_convolve_3d(&occupancy, nx, ny, nz, &kernel, radius);

    // Integer offsets inside the closed ball of radius `radius`.
    let r_sq = (radius * radius) as f32;
    let mut ball = Vec::new();
    for dk in -radius..=radius {
        for dj in -radius..=radius {
            for di in -radius..=radius {
                if (di * di + dj * dj + dk * dk) as f32 <= r_sq {
                    ball.push((di, dj, dk));
                }
            }
        }
    }

    // Dilate the occupied voxels by that ball to get the output mask.
    let mut mask = vec![false; n_voxels];
    for flat in 0..n_voxels {
        if scatter_counts[flat] == 0 {
            continue;
        }
        let i = (flat % nx) as isize;
        let j = ((flat / nx) % ny) as isize;
        let k = (flat / plane) as isize;
        for &(di, dj, dk) in &ball {
            let (ii, jj, kk) = (i + di, j + dj, k + dk);
            let in_grid = ii >= 0
                && ii < nx as isize
                && jj >= 0
                && jj < ny as isize
                && kk >= 0
                && kk < nz as isize;
            if in_grid {
                mask[ii as usize + jj as usize * nx + kk as usize * plane] = true;
            }
        }
    }

    mask.iter()
        .zip(denominator.iter().zip(&numerator))
        .map(|(&in_reach, (&weight, &signal))| {
            if in_reach && weight > 1e-8 {
                signal / weight
            } else {
                0.0
            }
        })
        .collect()
}
/// Correlates a flat `nx × ny × nz` grid (x fastest) with the same 1-D `kernel`
/// along x, then y, then z, treating samples outside the grid as zero.
///
/// `kernel[t]` weights the neighbour at offset `t - radius`. Only the first
/// `nx * ny * nz` entries of `input` are read, and the result has exactly that
/// many entries.
fn separable_convolve_3d(
    input: &[f32],
    nx: usize,
    ny: usize,
    nz: usize,
    kernel: &[f32],
    radius: isize,
) -> Vec<f32> {
    /// One 1-D pass over all `total` voxels along an axis of `extent` samples
    /// whose neighbours sit `stride` flat indices apart; `axis_pos` recovers a
    /// voxel's coordinate on that axis from its flat index.
    fn convolve_axis(
        src: &[f32],
        total: usize,
        kernel: &[f32],
        radius: isize,
        extent: usize,
        stride: usize,
        axis_pos: impl Fn(usize) -> usize,
    ) -> Vec<f32> {
        (0..total)
            .map(|flat| {
                let pos = axis_pos(flat);
                let line_start = flat - pos * stride;
                let mut acc = 0.0f32;
                for (tap, &weight) in kernel.iter().enumerate() {
                    let p = pos as isize + tap as isize - radius;
                    if p >= 0 && p < extent as isize {
                        acc += src[line_start + p as usize * stride] * weight;
                    }
                }
                acc
            })
            .collect()
    }

    let total = nx * ny * nz;
    let plane = nx * ny;
    let along_x = convolve_axis(input, total, kernel, radius, nx, 1, |f| f % nx);
    let along_y = convolve_axis(&along_x, total, kernel, radius, ny, nx, |f| (f / nx) % ny);
    convolve_axis(&along_y, total, kernel, radius, nz, plane, |f| f / plane)
}
/// Returns the left-hemisphere mesh coordinates followed by the right-hemisphere
/// ones as a single flat buffer (presumably xyz triples, as consumed by
/// `mni_to_voxel` — confirm against `BrainMesh`).
pub fn get_mesh_coords(brain: &crate::plotting::BrainMesh) -> Vec<f32> {
    let left = &brain.left.mesh.coords;
    let right = &brain.right.mesh.coords;
    let mut merged = Vec::with_capacity(left.len() + right.len());
    merged.extend_from_slice(left);
    merged.extend_from_slice(right);
    merged
}
/// Writes a single 3-D float32 volume as a NIfTI-1 single-file image.
///
/// The file layout is the 348-byte header, a 4-byte zero extension flag, and
/// the voxels as little-endian f32 in `volume` order. Output is gzip-compressed
/// when the path ends in `.gz` *or* `config.compress` is set. Note that the
/// latter applies even to a plain `.nii` path.
///
/// # Errors
/// Fails if `volume.len()` differs from the voxel count implied by
/// `config.dims`, or if the file cannot be created, written or finalised (the
/// error names the path).
pub fn write_nifti(
    path: &Path,
    volume: &[f32],
    config: &NiftiConfig,
) -> Result<()> {
    /// Creates `path` and writes `parts` back to back, gzip-compressed when `gzip` is set.
    fn write_parts(path: &Path, parts: &[&[u8]], gzip: bool) -> Result<()> {
        let file = std::fs::File::create(path)
            .with_context(|| format!("failed to create {}", path.display()))?;
        let write_failed = || format!("failed to write {}", path.display());
        if gzip {
            let mut gz = flate2::write::GzEncoder::new(file, flate2::Compression::default());
            for part in parts {
                gz.write_all(part).with_context(write_failed)?;
            }
            // `finish` emits the remaining compressed data and the gzip trailer;
            // relying on drop instead would discard any error from that step.
            gz.finish().with_context(write_failed)?;
        } else {
            let mut file = file;
            for part in parts {
                file.write_all(part).with_context(write_failed)?;
            }
        }
        Ok(())
    }

    let (nx, ny, nz) = config.dims;
    let expected = nx * ny * nz;
    if volume.len() != expected {
        anyhow::bail!(
            "Volume has {} voxels, expected {} ({}×{}×{})",
            volume.len(), expected, nx, ny, nz
        );
    }
    let header = build_nifti1_header(config);
    // Zero extension flag: no header extensions follow (data starts at byte 352).
    let extension = [0u8; 4];
    let data_bytes: Vec<u8> = volume.iter().flat_map(|v| v.to_le_bytes()).collect();
    let gzip = path.to_string_lossy().ends_with(".gz") || config.compress;
    let parts: [&[u8]; 3] = [&header, &extension, &data_bytes];
    write_parts(path, &parts, gzip)
}
/// Projects each timestep of `predictions` into a volume (via
/// `surface_to_volume`) and writes all of them as one 4-D float32 NIfTI-1 file.
///
/// The header is the 3-D one with `dim[0] = 4` and `dim[4] = nt`. The time step
/// `pixdim[4]` is fixed at 0.5, and units are mm / seconds. Compression follows
/// the same rule as `write_nifti`: `.gz` suffix or `config.compress`.
///
/// # Errors
/// Fails when there are no timesteps, when there are more than `i16::MAX`
/// timesteps (`dim[4]` is a 16-bit field), or when the file cannot be created,
/// written or finalised.
pub fn write_nifti_4d(
    path: &Path,
    predictions: &[Vec<f32>],
    vertex_coords: &[f32],
    config: &NiftiConfig,
) -> Result<()> {
    /// Creates `path` and writes `parts` back to back, gzip-compressed when `gzip` is set.
    fn write_parts(path: &Path, parts: &[&[u8]], gzip: bool) -> Result<()> {
        let file = std::fs::File::create(path)
            .with_context(|| format!("failed to create {}", path.display()))?;
        let write_failed = || format!("failed to write {}", path.display());
        if gzip {
            let mut gz = flate2::write::GzEncoder::new(file, flate2::Compression::default());
            for part in parts {
                gz.write_all(part).with_context(write_failed)?;
            }
            // `finish` emits the remaining compressed data and the gzip trailer;
            // relying on drop instead would discard any error from that step.
            gz.finish().with_context(write_failed)?;
        } else {
            let mut file = file;
            for part in parts {
                file.write_all(part).with_context(write_failed)?;
            }
        }
        Ok(())
    }

    let (nx, ny, nz) = config.dims;
    let nt = predictions.len();
    if nt == 0 {
        anyhow::bail!("No timesteps to write");
    }
    // dim[4] is a signed 16-bit field. Checking up front (before the expensive
    // projection) prevents `nt as i16` from wrapping into a corrupt header.
    if nt > i16::MAX as usize {
        anyhow::bail!(
            "Too many timesteps for a NIfTI-1 header: {} (maximum {})",
            nt, i16::MAX
        );
    }
    let mut hdr = build_nifti1_header(config);
    let dim_off = 40;
    // dim[0] = 4 (rank), dim[4] = number of timesteps.
    hdr[dim_off..dim_off + 2].copy_from_slice(&4i16.to_le_bytes());
    hdr[dim_off + 8..dim_off + 10].copy_from_slice(&(nt as i16).to_le_bytes());
    let pixdim_off = 76;
    // pixdim[4]: spacing between volumes, hard-coded to 0.5 (seconds per xyzt_units).
    hdr[pixdim_off + 16..pixdim_off + 20].copy_from_slice(&0.5f32.to_le_bytes());
    // xyzt_units = NIFTI_UNITS_MM (2) | NIFTI_UNITS_SEC (8).
    let xyzt_off = 123;
    hdr[xyzt_off] = 10;
    // Zero extension flag: no header extensions follow.
    let extension = [0u8; 4];
    let mut all_data: Vec<u8> = Vec::with_capacity(nx * ny * nz * nt * 4);
    for (ti, pred) in predictions.iter().enumerate() {
        let vol = surface_to_volume(pred, vertex_coords, config);
        all_data.extend(vol.iter().flat_map(|v| v.to_le_bytes()));
        if (ti + 1) % 50 == 0 {
            eprintln!(" NIfTI: projected {}/{} timesteps", ti + 1, nt);
        }
    }
    let gzip = path.to_string_lossy().ends_with(".gz") || config.compress;
    let parts: [&[u8]; 3] = [&hdr, &extension, &all_data];
    write_parts(path, &parts, gzip)?;
    eprintln!(
        "NIfTI written: {} ({}×{}×{}×{}, {:.1} MB)",
        path.display(), nx, ny, nz, nt,
        all_data.len() as f64 / 1e6
    );
    Ok(())
}
/// Loads the `lh.pial` and `rh.pial` surfaces of the `mesh` template and
/// returns their vertex coordinates concatenated (left first) as one flat buffer.
///
/// Surfaces are read from the template's `surf/` subdirectory when it exists,
/// otherwise from the template directory itself.
///
/// # Errors
/// Fails if the template directory cannot be located or either surface file
/// cannot be read.
pub fn load_pial_coords_mni(
    mesh: &str,
    base_path: Option<&str>,
) -> Result<Vec<f32>> {
    let mesh_dir = crate::fsaverage::find_fsaverage_dir(mesh, base_path)
        .ok_or_else(|| anyhow::anyhow!("Could not find {} mesh for NIfTI projection", mesh))?;
    let nested = mesh_dir.join("surf");
    let surf_dir = if nested.exists() { nested } else { mesh_dir.clone() };

    // Reads one hemisphere's surface file and keeps only its coordinates.
    let read_coords = |file_name: &str| -> Result<Vec<f32>> {
        let (coords, _, _, _) =
            crate::fsaverage::read_freesurfer_surface(&surf_dir.join(file_name))?;
        Ok(coords)
    };

    let mut combined = read_coords("lh.pial")?;
    combined.extend_from_slice(&read_coords("rh.pial")?);
    Ok(combined)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// x-fastest flat index of voxel `(i, j, k)` in a grid `nx` wide and `ny` tall.
    fn flat_index(i: usize, j: usize, k: usize, nx: usize, ny: usize) -> usize {
        i + j * nx + k * nx * ny
    }

    #[test]
    fn test_nifti_header_size() {
        let hdr = build_nifti1_header(&NiftiConfig::default());
        assert_eq!(hdr.len(), 348);
        let sizeof_hdr = i32::from_le_bytes([hdr[0], hdr[1], hdr[2], hdr[3]]);
        assert_eq!(sizeof_hdr, 348);
        assert_eq!(&hdr[344..], b"n+1\0");
    }

    #[test]
    fn test_mni_to_voxel() {
        let voxels = mni_to_voxel(&[0.0, 0.0, 0.0], &NiftiConfig::default());
        assert_eq!(voxels, vec![Some((45, 63, 36))]);
    }

    #[test]
    fn test_surface_to_volume_basic() {
        let config = NiftiConfig {
            dims: (10, 10, 10),
            voxel_size: 1.0,
            origin_mm: (5.0, 5.0, 5.0),
            compress: false,
            smooth_fwhm_mm: 0.0,
        };
        let vol = surface_to_volume(&[42.0], &[0.0, 0.0, 0.0], &config);
        assert_eq!(vol.len(), 1000);
        assert_eq!(vol[flat_index(5, 5, 5, 10, 10)], 42.0);
    }

    #[test]
    fn test_surface_to_volume_smoothed() {
        let config = NiftiConfig {
            dims: (20, 20, 20),
            voxel_size: 1.0,
            origin_mm: (10.0, 10.0, 10.0),
            compress: false,
            smooth_fwhm_mm: 3.0,
        };
        let vol = surface_to_volume(&[42.0], &[0.0, 0.0, 0.0], &config);
        let at = |i, j, k| vol[flat_index(i, j, k, 20, 20)];
        assert!(at(10, 10, 10) > 0.0, "center voxel should be > 0");
        assert!(at(11, 10, 10) > 0.0, "neighbor should be > 0 after smoothing");
        assert_eq!(at(0, 0, 0), 0.0, "far voxel should be 0");
    }

    #[test]
    fn test_write_nifti_roundtrip() {
        let config = NiftiConfig {
            dims: (4, 4, 4),
            voxel_size: 2.0,
            origin_mm: (4.0, 4.0, 4.0),
            compress: false,
            smooth_fwhm_mm: 0.0,
        };
        let volume = vec![1.0f32; 4 * 4 * 4];
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.nii");
        write_nifti(&path, &volume, &config).unwrap();
        let written = std::fs::read(&path).unwrap();
        // header + 4-byte extension flag + one little-endian f32 per voxel
        assert_eq!(written.len(), 348 + 4 + volume.len() * 4);
    }
}