use std::collections::HashMap;
use scirs2_core::ndarray::{Array1, Array3, ArrayView3};
use crate::error::{NdimageError, NdimageResult};
/// Subset of DICOM header attributes describing one image series.
///
/// Only the fields below are modelled. Call [`DicomHeader::validate`] before relying on
/// the geometric fields; the struct itself performs no checking.
#[derive(Debug, Clone, PartialEq)]
pub struct DicomHeader {
/// Patient identifier.
pub patient_id: String,
/// Modality code; the constructors set `"CT"` or `"MR"`.
pub modality: String,
/// In-plane spacing. `voxel_spacing` places `.0` on the row (y) axis and `.1` on the
/// column (x) axis. Presumably DICOM Pixel Spacing (row, column) in mm — TODO confirm.
pub pixel_spacing: (f64, f64),
/// Used as the slice-axis (z) spacing by `voxel_spacing`; must be > 0 per `validate`.
pub slice_thickness: f64,
/// Number of rows per slice.
pub rows: usize,
/// Number of columns per slice.
pub columns: usize,
/// Number of slices in the series.
pub num_slices: usize,
/// Linear rescale slope for stored values; `validate` rejects 0.
pub rescale_slope: f64,
/// Linear rescale intercept; the CT constructor uses -1024, the MR constructor 0.
pub rescale_intercept: f64,
/// Free-text series description (empty by default).
pub series_description: String,
/// Orientation direction cosines. Presumably row cosines then column cosines as in DICOM
/// Image Orientation (Patient) — TODO confirm; constructors set the identity orientation.
pub image_orientation: [f64; 6],
/// Position of the first voxel. Presumably DICOM Image Position (Patient) in mm — TODO confirm.
pub image_position: [f64; 3],
}
impl DicomHeader {
    /// Creates a CT header with unit spacing, identity orientation, origin at zero and the
    /// common CT rescale (`slope = 1`, `intercept = -1024`).
    pub fn new_ct(
        patient_id: impl Into<String>,
        rows: usize,
        columns: usize,
        num_slices: usize,
    ) -> Self {
        Self::defaults_for(patient_id.into(), "CT", rows, columns, num_slices, -1024.0)
    }

    /// Creates an MR header with unit spacing, identity orientation, origin at zero and an
    /// identity rescale (`slope = 1`, `intercept = 0`).
    pub fn new_mr(
        patient_id: impl Into<String>,
        rows: usize,
        columns: usize,
        num_slices: usize,
    ) -> Self {
        Self::defaults_for(patient_id.into(), "MR", rows, columns, num_slices, 0.0)
    }

    /// Shared body of the modality constructors; only modality and intercept differ.
    fn defaults_for(
        patient_id: String,
        modality: &str,
        rows: usize,
        columns: usize,
        num_slices: usize,
        rescale_intercept: f64,
    ) -> Self {
        Self {
            patient_id,
            modality: modality.to_string(),
            pixel_spacing: (1.0, 1.0),
            slice_thickness: 1.0,
            rows,
            columns,
            num_slices,
            rescale_slope: 1.0,
            rescale_intercept,
            series_description: String::new(),
            image_orientation: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
            image_position: [0.0, 0.0, 0.0],
        }
    }

    /// Voxel spacing as `[slice_thickness, pixel_spacing.0, pixel_spacing.1]`, i.e. in the
    /// `(z, y, x)` axis order used by `MedicalVolume`.
    pub fn voxel_spacing(&self) -> [f64; 3] {
        [
            self.slice_thickness,
            self.pixel_spacing.0,
            self.pixel_spacing.1,
        ]
    }

    /// Checks that the header describes a usable, non-degenerate volume.
    ///
    /// # Errors
    /// Returns `NdimageError::InvalidInput` if
    /// * `rows`, `columns` or `num_slices` is zero;
    /// * `slice_thickness` or either `pixel_spacing` entry is not a finite positive number
    ///   (NaN and infinity are rejected, not only values `<= 0`);
    /// * `rescale_slope` is zero, or either rescale parameter is not finite.
    pub fn validate(&self) -> NdimageResult<()> {
        if self.rows == 0 || self.columns == 0 || self.num_slices == 0 {
            return Err(NdimageError::InvalidInput(
                "DicomHeader: rows, columns, and num_slices must be > 0".to_string(),
            ));
        }
        // `x <= 0.0` alone lets NaN through, because every comparison with NaN is false.
        let not_positive_finite = |x: f64| !x.is_finite() || x <= 0.0;
        if not_positive_finite(self.slice_thickness) {
            return Err(NdimageError::InvalidInput(
                "DicomHeader: slice_thickness must be positive and finite".to_string(),
            ));
        }
        if not_positive_finite(self.pixel_spacing.0) || not_positive_finite(self.pixel_spacing.1) {
            return Err(NdimageError::InvalidInput(
                "DicomHeader: pixel_spacing values must be positive and finite".to_string(),
            ));
        }
        if self.rescale_slope == 0.0 {
            return Err(NdimageError::InvalidInput(
                "DicomHeader: rescale_slope must not be zero".to_string(),
            ));
        }
        if !self.rescale_slope.is_finite() || !self.rescale_intercept.is_finite() {
            return Err(NdimageError::InvalidInput(
                "DicomHeader: rescale_slope and rescale_intercept must be finite".to_string(),
            ));
        }
        Ok(())
    }
}
/// A 3-D scalar volume together with its voxel geometry.
///
/// Axis order is `(z, y, x)`: `shape` returns `(nz, ny, nx)` and voxels are addressed as
/// `data[[iz, iy, ix]]`.
#[derive(Debug, Clone)]
pub struct MedicalVolume {
/// Voxel intensities indexed `[iz, iy, ix]`.
pub data: Array3<f64>,
/// Voxel spacing per axis in `(z, y, x)` order; `new` requires every entry to be > 0.
/// Units presumably mm (see `physical_volume_mm3`) — TODO confirm.
pub spacing: [f64; 3],
/// Flattened 3×3 direction matrix; `new` sets the identity. NOTE(review): presumably
/// row-major — confirm. `voxel_to_physical` does not currently apply it.
pub direction: [f64; 9],
/// Physical coordinate of voxel `[0, 0, 0]`, in the `(z, y, x)` order used by `voxel_to_physical`.
pub origin: [f64; 3],
/// Optional DICOM header associated with the volume (`None` after `new`).
pub header: Option<DicomHeader>,
}
impl MedicalVolume {
    /// Creates a volume with an identity direction matrix and no header.
    ///
    /// `spacing` and `origin` use the same `(z, y, x)` axis order as `data`.
    ///
    /// # Errors
    /// Returns `NdimageError::InvalidInput` if any spacing entry is not a finite, strictly
    /// positive number (zero, negative, NaN and infinite spacings are all rejected).
    pub fn new(data: Array3<f64>, spacing: [f64; 3], origin: [f64; 3]) -> NdimageResult<Self> {
        // `s <= 0.0` alone would accept NaN, since every comparison with NaN is false.
        if let Some((i, s)) = spacing
            .iter()
            .copied()
            .enumerate()
            .find(|&(_, s)| !s.is_finite() || s <= 0.0)
        {
            return Err(NdimageError::InvalidInput(format!(
                "MedicalVolume: spacing[{}] must be positive and finite, got {}",
                i, s
            )));
        }
        Ok(Self {
            data,
            spacing,
            direction: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            origin,
            header: None,
        })
    }

    /// Attaches `header`, replacing any existing one.
    pub fn with_header(mut self, header: DicomHeader) -> Self {
        self.header = Some(header);
        self
    }

    /// Array dimensions as `(nz, ny, nx)`.
    pub fn shape(&self) -> (usize, usize, usize) {
        self.data.dim()
    }

    /// Total number of voxels.
    pub fn num_voxels(&self) -> usize {
        self.data.len()
    }

    /// Voxel count times the volume of one voxel (`spacing[0] * spacing[1] * spacing[2]`).
    /// The result is in mm³ only if `spacing` is in mm.
    pub fn physical_volume_mm3(&self) -> f64 {
        self.num_voxels() as f64 * self.spacing[0] * self.spacing[1] * self.spacing[2]
    }

    /// Physical coordinate of voxel `(iz, iy, ix)` as `origin + index * spacing` per axis.
    ///
    /// NOTE(review): `direction` is not applied, so the result is only correct for an
    /// identity direction matrix (the value `new` sets).
    pub fn voxel_to_physical(&self, iz: usize, iy: usize, ix: usize) -> [f64; 3] {
        [
            self.origin[0] + iz as f64 * self.spacing[0],
            self.origin[1] + iy as f64 * self.spacing[1],
            self.origin[2] + ix as f64 * self.spacing[2],
        ]
    }

    /// Copies slice `iz` (fixed first axis) into a row-major `Vec` of length `ny * nx`:
    /// all `ix` for `iy = 0`, then `iy = 1`, and so on.
    ///
    /// # Errors
    /// Returns `NdimageError::InvalidInput` if `iz >= nz`.
    pub fn axial_slice(&self, iz: usize) -> NdimageResult<Vec<f64>> {
        match self.data.outer_iter().nth(iz) {
            // A 2-D view iterates in logical (row-major) order.
            Some(plane) => Ok(plane.iter().copied().collect()),
            None => Err(NdimageError::InvalidInput(format!(
                "MedicalVolume: slice index {} out of range [0, {})",
                iz,
                self.shape().0
            ))),
        }
    }

    /// Borrowed view of the voxel data.
    pub fn view(&self) -> ArrayView3<f64> {
        self.data.view()
    }
}
/// Display window given by a centre (`level`) and a width (`window`); `apply` maps the
/// interval `[level - window/2, level + window/2]` onto `[0, 1]`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WindowLeveling {
/// Window centre, in the data's intensity units (presumably HU for the `ct_*` presets).
pub level: f64,
/// Window width. `WindowLeveling::new` rejects values <= 0; building the struct
/// literally bypasses that check.
pub window: f64,
}
impl WindowLeveling {
    /// CT abdomen preset: level 60, window 400.
    pub fn ct_abdomen() -> Self {
        Self { level: 60.0, window: 400.0 }
    }

    /// CT lung preset: level -600, window 1500.
    pub fn ct_lung() -> Self {
        Self { level: -600.0, window: 1500.0 }
    }

    /// CT bone preset: level 400, window 1800.
    pub fn ct_bone() -> Self {
        Self { level: 400.0, window: 1800.0 }
    }

    /// CT brain preset: level 40, window 80.
    pub fn ct_brain() -> Self {
        Self { level: 40.0, window: 80.0 }
    }

    /// Creates a custom window.
    ///
    /// # Errors
    /// Returns `NdimageError::InvalidInput` if `window` is not a finite positive number
    /// (NaN and infinity included) or if `level` is not finite.
    pub fn new(level: f64, window: f64) -> NdimageResult<Self> {
        // `window <= 0.0` alone would accept NaN.
        if !window.is_finite() || window <= 0.0 {
            return Err(NdimageError::InvalidInput(
                "WindowLeveling: window must be positive and finite".to_string(),
            ));
        }
        if !level.is_finite() {
            return Err(NdimageError::InvalidInput(
                "WindowLeveling: level must be finite".to_string(),
            ));
        }
        Ok(Self { level, window })
    }

    /// Maps `value` linearly from `[low, high]` (see [`WindowLeveling::bounds`]) onto
    /// `[0, 1]`, returning exactly 0 at or below `low` and exactly 1 at or above `high`.
    /// A NaN input yields NaN.
    pub fn apply(&self, value: f64) -> f64 {
        let (low, high) = self.bounds();
        if value <= low {
            0.0
        } else if value >= high {
            1.0
        } else {
            (value - low) / self.window
        }
    }

    /// Applies [`WindowLeveling::apply`] to every voxel, returning a new array.
    pub fn apply_to_volume(&self, volume: &Array3<f64>) -> Array3<f64> {
        volume.mapv(|v| self.apply(v))
    }

    /// Returns `(level - window/2, level + window/2)`.
    pub fn bounds(&self) -> (f64, f64) {
        (
            self.level - self.window * 0.5,
            self.level + self.window * 0.5,
        )
    }
}
/// Data-less namespace for Hounsfield-unit helpers: rescaling stored values and
/// classifying HU values into coarse [`Tissue`] classes.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HounsfieldUnits;
/// Coarse tissue classes assigned by `HounsfieldUnits::classify`, listed from least to
/// most dense.
///
/// Declaration order is significant: `HounsfieldUnits::classify_volume` stores classes as
/// `tissue as u8`, so `Air == 0` through `Metal == 7`. Do not reorder or insert variants
/// without updating consumers of that encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Tissue {
Air,
Lung,
Fat,
SoftTissue,
Blood,
CancellousBone,
CorticalBone,
Metal,
}
impl HounsfieldUnits {
    /// Converts a stored pixel value to Hounsfield units with the linear rescale
    /// `slope * stored_value + intercept`.
    pub fn rescale(stored_value: f64, slope: f64, intercept: f64) -> f64 {
        let scaled = stored_value * slope;
        scaled + intercept
    }

    /// Maps a Hounsfield value to its coarse tissue class.
    ///
    /// Each class covers the half-open interval `(lower, upper]` reported by
    /// [`HounsfieldUnits::tissue_range`]. Values at or below -950 HU are `Air`, and so is
    /// NaN, because it exceeds no threshold.
    pub fn classify(hu: f64) -> Tissue {
        // Exclusive lower bound of every class denser than air, densest first: the first
        // bound that `hu` strictly exceeds decides the class.
        const THRESHOLDS: [(f64, Tissue); 7] = [
            (3000.0, Tissue::Metal),
            (700.0, Tissue::CorticalBone),
            (300.0, Tissue::CancellousBone),
            (100.0, Tissue::Blood),
            (-100.0, Tissue::SoftTissue),
            (-500.0, Tissue::Fat),
            (-950.0, Tissue::Lung),
        ];
        THRESHOLDS
            .iter()
            .find(|&&(lower, _)| hu > lower)
            .map_or(Tissue::Air, |&(_, tissue)| tissue)
    }

    /// Rescales raw stored values to HU and classifies every voxel.
    ///
    /// Returns `(hu_volume, class_volume)`, where each class is encoded as `Tissue as u8`.
    pub fn classify_volume(
        raw: &Array3<f64>,
        slope: f64,
        intercept: f64,
    ) -> (Array3<f64>, Array3<u8>) {
        let hu_volume = raw.mapv(|stored| Self::rescale(stored, slope, intercept));
        let class_volume = hu_volume.map(|&hu| Self::classify(hu) as u8);
        (hu_volume, class_volume)
    }

    /// Returns the `(lower, upper]` HU interval that `classify` maps to `tissue`; the
    /// outermost classes are unbounded on their open side.
    pub fn tissue_range(tissue: Tissue) -> (f64, f64) {
        match tissue {
            Tissue::Metal => (3000.0, f64::INFINITY),
            Tissue::CorticalBone => (700.0, 3000.0),
            Tissue::CancellousBone => (300.0, 700.0),
            Tissue::Blood => (100.0, 300.0),
            Tissue::SoftTissue => (-100.0, 100.0),
            Tissue::Fat => (-500.0, -100.0),
            Tissue::Lung => (-950.0, -500.0),
            Tissue::Air => (f64::NEG_INFINITY, -950.0),
        }
    }
}
/// Tuning parameters for [`N4BiasCorrection`].
#[derive(Debug, Clone)]
pub struct N4Config {
/// Total degree of the polynomial bias model; `N4BiasCorrection` caps it at 4.
pub poly_degree: usize,
/// Upper bound on the number of fit-and-subtract iterations.
pub max_iterations: usize,
/// Tolerance `N4BiasCorrection::correct` compares its per-iteration RMS-update criterion
/// against to decide that the iteration has converged.
pub convergence_threshold: f64,
/// Foreground threshold as a fraction of the shifted global mean intensity; only voxels
/// above it take part in the fit.
pub mask_fraction: f64,
}
impl Default for N4Config {
    /// Quadratic bias model, at most 50 iterations, convergence tolerance `1e-4`, and a
    /// foreground threshold at 2 % of the shifted mean intensity.
    fn default() -> Self {
        let poly_degree = 2;
        let max_iterations = 50;
        let convergence_threshold = 1e-4;
        let mask_fraction = 0.02;
        Self {
            poly_degree,
            max_iterations,
            convergence_threshold,
            mask_fraction,
        }
    }
}
/// Retrospective multiplicative intensity-bias correction for a `MedicalVolume`.
///
/// NOTE(review): despite the N4 name, the visible implementation fits a global low-order
/// polynomial in voxel coordinates (built from `polynomial_basis_3d`) rather than the
/// multi-resolution B-spline model of N4ITK.
pub struct N4BiasCorrection {
/// Settings used by every call to `correct`.
config: N4Config,
}
impl N4BiasCorrection {
    /// Creates a corrector using [`N4Config::default`].
    pub fn new() -> Self {
        Self {
            config: N4Config::default(),
        }
    }

    /// Creates a corrector with caller-supplied settings.
    pub fn with_config(config: N4Config) -> Self {
        Self { config }
    }

    /// Estimates a smooth multiplicative bias field and removes it from `volume`.
    ///
    /// Steps:
    /// 1. If the volume has non-positive values, shift it so its minimum becomes 1.
    /// 2. Take the logarithm of the shifted values.
    /// 3. Fit a low-order polynomial in normalized voxel coordinates to the log signal of
    ///    the foreground mask, defined as `value + shift > mask_fraction * (mean + shift)`.
    /// 4. Centre the fitted log-bias to zero mean over the mask, then subtract it.
    ///
    /// Centring keeps the corrected image on the original intensity scale. The returned
    /// bias field therefore has a geometric mean of 1 over the mask. Fitting and subtraction
    /// repeat until the RMS update over the mask, or its relative change between
    /// iterations, falls below `convergence_threshold`, or `max_iterations` is reached.
    ///
    /// Returns the corrected volume (same geometry and header) and the multiplicative bias
    /// field, both shaped like the input. The bias is divided out of the *shifted*
    /// intensities: `corrected = (value + shift) / bias - shift`.
    ///
    /// # Errors
    /// * `InvalidInput` if the volume has a zero-length axis.
    /// * `ComputationError` if the polynomial least-squares system is singular
    ///   (e.g. a degenerate foreground mask).
    pub fn correct(&self, volume: &MedicalVolume) -> NdimageResult<(MedicalVolume, Array3<f64>)> {
        let data = &volume.data;
        let (nz, ny, nx) = volume.shape();
        if nz == 0 || ny == 0 || nx == 0 {
            return Err(NdimageError::InvalidInput(
                "N4BiasCorrection: volume must not be empty".to_string(),
            ));
        }
        let n = nz * ny * nx;

        // Shift non-positive data so its smallest value becomes 1 and `ln` is defined.
        let min_val = data.iter().copied().fold(f64::INFINITY, f64::min);
        let shift = if min_val <= 0.0 { 1.0 - min_val } else { 0.0 };
        let global_mean = data.iter().sum::<f64>() / n as f64;
        let threshold = self.config.mask_fraction * (global_mean + shift);

        // `iter()` walks the array in logical row-major order, so entry `k` of these buffers
        // is voxel `(k / (ny * nx), (k / nx) % ny, k % nx)` — the layout the fit expects.
        let (log_signal, mask): (Vec<f64>, Vec<bool>) = data
            .iter()
            .map(|&raw| {
                let shifted = raw + shift;
                (shifted.ln(), shifted > threshold)
            })
            .unzip();
        let masked_count = mask.iter().filter(|&&m| m).count();

        let degree = self.config.poly_degree.min(4);
        let tol = self.config.convergence_threshold;
        let mut corrected_log = log_signal.clone();
        let mut prev_rmse = f64::INFINITY;
        for _ in 0..self.config.max_iterations {
            let mut bias_log =
                self.fit_polynomial_bias(&corrected_log, &mask, nz, ny, nx, degree)?;
            if masked_count > 0 {
                // The basis contains a constant term. Left alone, it would absorb the mean
                // log-intensity and collapse the corrected image towards `exp(0) - shift`.
                // Centring over the mask removes only the spatially varying part.
                let masked_mean = bias_log
                    .iter()
                    .zip(&mask)
                    .filter(|&(_, &m)| m)
                    .map(|(&b, _)| b)
                    .sum::<f64>()
                    / masked_count as f64;
                bias_log.iter_mut().for_each(|b| *b -= masked_mean);
            }

            // RMS over the mask of this iteration's update to the log signal.
            let rmse = if masked_count > 0 {
                let sq_sum: f64 = bias_log
                    .iter()
                    .zip(&mask)
                    .filter(|&(_, &m)| m)
                    .map(|(&b, _)| b * b)
                    .sum();
                (sq_sum / masked_count as f64).sqrt()
            } else {
                0.0
            };
            for (c, &b) in corrected_log.iter_mut().zip(&bias_log) {
                *c -= b;
            }

            // `prev_rmse` starts at infinity, so the first relative change is NaN and never
            // counts as converged; the absolute test catches updates that are already tiny.
            let rel_change = (prev_rmse - rmse).abs() / (prev_rmse + 1e-12);
            if rmse < tol || rel_change < tol {
                break;
            }
            prev_rmse = rmse;
        }

        let flat = |iz: usize, iy: usize, ix: usize| (iz * ny + iy) * nx + ix;
        let corrected_data = Array3::from_shape_fn((nz, ny, nx), |(iz, iy, ix)| {
            corrected_log[flat(iz, iy, ix)].exp() - shift
        });
        let bias_field = Array3::from_shape_fn((nz, ny, nx), |(iz, iy, ix)| {
            let k = flat(iz, iy, ix);
            (log_signal[k] - corrected_log[k]).exp()
        });
        let new_volume = MedicalVolume {
            data: corrected_data,
            spacing: volume.spacing,
            direction: volume.direction,
            origin: volume.origin,
            header: volume.header.clone(),
        };
        Ok((new_volume, bias_field))
    }

    /// Least-squares fit of a polynomial in normalized voxel coordinates to `log_signal`
    /// over the voxels where `mask` is true, evaluated at every voxel.
    ///
    /// Layout and coordinates:
    /// * `log_signal` and `mask` are flat row-major buffers of length `nz * ny * nx`.
    /// * Each axis index is mapped linearly onto `[-1, 1]`; a length-1 axis maps to -1.
    ///
    /// The basis is every monomial of total degree `<= degree`, except powers an axis
    /// cannot resolve. An axis with `len` samples only supports powers `< len`: with
    /// `len == 2`, `z^2` coincides with the constant term, and with `len == 1`, `z` does.
    /// Keeping those powers would make the normal equations singular, e.g. for 2-D
    /// `(1, ny, nx)` volumes.
    ///
    /// Returns all zeros when fewer voxels are masked than there are basis functions.
    fn fit_polynomial_bias(
        &self,
        log_signal: &[f64],
        mask: &[bool],
        nz: usize,
        ny: usize,
        nx: usize,
        degree: usize,
    ) -> NdimageResult<Vec<f64>> {
        let basis_fns: Vec<(usize, usize, usize)> = polynomial_basis_3d(degree)
            .into_iter()
            .filter(|&(pz, py, px)| pz < nz && py < ny && px < nx)
            .collect();
        let nb = basis_fns.len();
        let n = nz * ny * nx;
        let masked_indices: Vec<usize> = (0..n).filter(|&i| mask[i]).collect();
        let nm = masked_indices.len();
        if nm < nb {
            return Ok(vec![0.0; n]);
        }

        // Per-axis lookup tables, e.g. `zp[iz][p] == zn(iz).powi(p)`.
        let zp = Self::axis_powers(nz, degree);
        let yp = Self::axis_powers(ny, degree);
        let xp = Self::axis_powers(nx, degree);

        // NOTE(review): the dense design matrix holds `nm * nb` f64s (roughly 2.8 GB for
        // 10M masked voxels at degree 4). Accumulating AᵀA directly would avoid that.
        let a_mat: Vec<Vec<f64>> = masked_indices
            .iter()
            .map(|&idx| {
                let (iz, iy, ix) = (idx / (ny * nx), (idx / nx) % ny, idx % nx);
                basis_fns
                    .iter()
                    .map(|&(pz, py, px)| zp[iz][pz] * yp[iy][py] * xp[ix][px])
                    .collect()
            })
            .collect();
        let b_vec: Vec<f64> = masked_indices.iter().map(|&idx| log_signal[idx]).collect();
        let coeffs = solve_normal_equations(&a_mat, &b_vec, nb)?;

        let mut bias = Vec::with_capacity(n);
        for iz in 0..nz {
            for iy in 0..ny {
                for ix in 0..nx {
                    let val = basis_fns
                        .iter()
                        .zip(&coeffs)
                        .map(|(&(pz, py, px), &c)| c * zp[iz][pz] * yp[iy][py] * xp[ix][px])
                        .sum::<f64>();
                    bias.push(val);
                }
            }
        }
        Ok(bias)
    }

    /// Returns `table[i][p] = t_i.powi(p)` for `p` in `0..=degree`. Here `t_i` maps index
    /// `i` of an axis with `len` samples linearly onto `[-1, 1]`; a single sample maps to -1.
    fn axis_powers(len: usize, degree: usize) -> Vec<Vec<f64>> {
        let denom = (len as f64 - 1.0).max(1.0);
        (0..len)
            .map(|i| {
                let t = 2.0 * i as f64 / denom - 1.0;
                (0..=degree).map(|p| t.powi(p as i32)).collect()
            })
            .collect()
    }
}

impl Default for N4BiasCorrection {
    /// Same as [`N4BiasCorrection::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Lists the exponent triples `(pz, py, px)` of every monomial `z^pz * y^py * x^px` whose
/// total degree `pz + py + px` is at most `degree`.
///
/// Terms are grouped by increasing total degree. Within one degree, `pz` increases, then
/// `py` increases (so `px` decreases). The result has
/// `(degree + 1)(degree + 2)(degree + 3) / 6` entries.
fn polynomial_basis_3d(degree: usize) -> Vec<(usize, usize, usize)> {
    (0..=degree)
        .flat_map(|total| {
            (0..=total).flat_map(move |pz| {
                (0..=total - pz).map(move |py| (pz, py, total - pz - py))
            })
        })
        .collect()
}
/// Solves the linear least-squares problem `min ‖A·x − b‖₂` via the normal equations
/// `(AᵀA)·x = Aᵀb`, using Gauss–Jordan elimination with partial pivoting.
///
/// `a` holds the design matrix as rows with at least `nb` entries; only the first `nb`
/// columns are used. `b` holds one right-hand-side value per row. Returns the `nb`
/// coefficients.
///
/// # Errors
/// * `DimensionError` if `a` and `b` differ in length, or a row of `a` has fewer than `nb`
///   entries.
/// * `ComputationError` if the normal matrix is numerically singular. The pivot tolerance
///   is relative to the largest entry of `AᵀA` and never drops below the absolute floor
///   `1e-15`. Without the relative part, the test would depend on how many rows were
///   accumulated.
fn solve_normal_equations(a: &[Vec<f64>], b: &[f64], nb: usize) -> NdimageResult<Vec<f64>> {
    // Pivots smaller than this fraction of the largest `AᵀA` entry are treated as zero.
    const RELATIVE_PIVOT_TOL: f64 = 1e-12;
    // Absolute floor; also rejects an all-zero design matrix.
    const ABSOLUTE_PIVOT_TOL: f64 = 1e-15;

    if a.len() != b.len() {
        return Err(NdimageError::DimensionError(format!(
            "solve_normal_equations: {} design rows but {} right-hand-side values",
            a.len(),
            b.len()
        )));
    }
    if a.iter().any(|row| row.len() < nb) {
        return Err(NdimageError::DimensionError(format!(
            "solve_normal_equations: every design row needs at least {} columns",
            nb
        )));
    }

    // Accumulate the upper triangle of AᵀA (plus Aᵀb), then mirror it: AᵀA is symmetric.
    let mut ata = vec![vec![0.0_f64; nb]; nb];
    let mut atb = vec![0.0_f64; nb];
    for (row, &rhs) in a.iter().zip(b) {
        let row = &row[..nb];
        for (i, &ri) in row.iter().enumerate() {
            atb[i] += ri * rhs;
            for (j, &rj) in row.iter().enumerate().skip(i) {
                ata[i][j] += ri * rj;
            }
        }
    }
    for i in 1..nb {
        for j in 0..i {
            ata[i][j] = ata[j][i];
        }
    }

    let scale = ata.iter().flatten().fold(0.0_f64, |m, &v| m.max(v.abs()));
    let tol = (scale * RELATIVE_PIVOT_TOL).max(ABSOLUTE_PIVOT_TOL);

    let mut aug: Vec<Vec<f64>> = ata
        .into_iter()
        .zip(atb)
        .map(|(mut r, rhs)| {
            r.push(rhs);
            r
        })
        .collect();

    for col in 0..nb {
        // Partial pivoting: move the largest remaining entry of this column onto the diagonal.
        let mut max_row = col;
        let mut max_val = aug[col][col].abs();
        for row in (col + 1)..nb {
            let v = aug[row][col].abs();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        if max_val < tol {
            return Err(NdimageError::ComputationError(
                "N4BiasCorrection: singular normal equations — bias field underdetermined".to_string(),
            ));
        }
        aug.swap(col, max_row);
        let pivot = aug[col][col];
        for j in col..=nb {
            aug[col][j] /= pivot;
        }
        // Gauss–Jordan: clear this column in every other row, above and below the pivot.
        for row in 0..nb {
            if row != col {
                let factor = aug[row][col];
                for j in col..=nb {
                    let val = aug[col][j] * factor;
                    aug[row][j] -= val;
                }
            }
        }
    }
    Ok(aug.iter().map(|r| r[nb]).collect())
}
/// Summary statistics for the voxels of one tissue class (or of a mask), as produced by
/// `compute_tissue_stats`. An empty class has `voxel_count == 0` and all statistics 0.
#[derive(Debug, Clone)]
pub struct TissueClassStats {
/// Class the statistics describe; `VolumeStats::compute_masked` fills in `Tissue::SoftTissue`.
pub tissue: Tissue,
/// Number of voxels summarized.
pub voxel_count: usize,
/// Arithmetic mean.
pub mean: f64,
/// Population standard deviation (divides by `voxel_count`).
pub std_dev: f64,
/// Smallest value.
pub min: f64,
/// Largest value.
pub max: f64,
/// 5th percentile (linear interpolation between neighbouring ranks).
pub p5: f64,
/// 25th percentile.
pub p25: f64,
/// 50th percentile.
pub median: f64,
/// 75th percentile.
pub p75: f64,
/// 95th percentile.
pub p95: f64,
}
/// Intensity statistics for a whole volume (`compute_ct`) or a masked region (`compute_masked`).
#[derive(Debug, Clone)]
pub struct VolumeStats {
/// Per-class statistics. `compute_ct` keys entries by the tissue's `Debug` name
/// (e.g. `"SoftTissue"`); `compute_masked` stores a single `"Masked"` entry.
pub by_tissue: HashMap<String, TissueClassStats>,
/// Mean over all considered voxels.
pub global_mean: f64,
/// Population standard deviation over all considered voxels.
pub global_std: f64,
/// Number of voxels considered (all voxels, or those selected by the mask).
pub total_voxels: usize,
}
impl VolumeStats {
pub fn compute_ct(volume: &MedicalVolume) -> NdimageResult<Self> {
let data = &volume.data;
let n = data.len();
if n == 0 {
return Err(NdimageError::InvalidInput(
"VolumeStats: volume must not be empty".to_string(),
));
}
let voxels: Vec<f64> = data.iter().cloned().collect();
let global_mean = voxels.iter().sum::<f64>() / n as f64;
let global_var = voxels.iter().map(|v| (v - global_mean).powi(2)).sum::<f64>() / n as f64;
let global_std = global_var.sqrt();
let mut by_class: HashMap<String, Vec<f64>> = HashMap::new();
for &v in &voxels {
let t = HounsfieldUnits::classify(v);
by_class.entry(format!("{:?}", t)).or_default().push(v);
}
let mut by_tissue = HashMap::new();
for (name, mut vals) in by_class {
let tissue = tissue_from_str(&name);
let stats = compute_tissue_stats(tissue, &mut vals);
by_tissue.insert(name, stats);
}
Ok(Self {
by_tissue,
global_mean,
global_std,
total_voxels: n,
})
}
pub fn compute_masked(volume: &MedicalVolume, mask: &Array3<bool>) -> NdimageResult<Self> {
let data = &volume.data;
let shape = data.shape();
if shape != mask.shape() {
return Err(NdimageError::DimensionError(
"VolumeStats: volume and mask shapes must match".to_string(),
));
}
let masked_voxels: Vec<f64> = data
.iter()
.zip(mask.iter())
.filter_map(|(&v, &m)| if m { Some(v) } else { None })
.collect();
let n = masked_voxels.len();
if n == 0 {
return Err(NdimageError::InvalidInput(
"VolumeStats: mask selects zero voxels".to_string(),
));
}
let global_mean = masked_voxels.iter().sum::<f64>() / n as f64;
let global_var = masked_voxels
.iter()
.map(|v| (v - global_mean).powi(2))
.sum::<f64>()
/ n as f64;
let global_std = global_var.sqrt();
let mut vals = masked_voxels.clone();
let stats = compute_tissue_stats(Tissue::SoftTissue, &mut vals);
let mut by_tissue = HashMap::new();
by_tissue.insert("Masked".to_string(), stats);
Ok(Self {
by_tissue,
global_mean,
global_std,
total_voxels: n,
})
}
}
fn compute_tissue_stats(tissue: Tissue, vals: &mut Vec<f64>) -> TissueClassStats {
let n = vals.len();
if n == 0 {
return TissueClassStats {
tissue,
voxel_count: 0,
mean: 0.0,
std_dev: 0.0,
min: 0.0,
max: 0.0,
p5: 0.0,
p25: 0.0,
median: 0.0,
p75: 0.0,
p95: 0.0,
};
}
vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let mean = vals.iter().sum::<f64>() / n as f64;
let var = vals.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n as f64;
let std_dev = var.sqrt();
let percentile = |p: f64| -> f64 {
let idx_f = p / 100.0 * (n - 1) as f64;
let lo = idx_f.floor() as usize;
let hi = (lo + 1).min(n - 1);
let frac = idx_f - lo as f64;
vals[lo] * (1.0 - frac) + vals[hi] * frac
};
TissueClassStats {
tissue,
voxel_count: n,
mean,
std_dev,
min: vals[0],
max: vals[n - 1],
p5: percentile(5.0),
p25: percentile(25.0),
median: percentile(50.0),
p75: percentile(75.0),
p95: percentile(95.0),
}
}
fn tissue_from_str(s: &str) -> Tissue {
match s {
"Air" => Tissue::Air,
"Lung" => Tissue::Lung,
"Fat" => Tissue::Fat,
"Blood" => Tissue::Blood,
"CancellousBone" => Tissue::CancellousBone,
"CorticalBone" => Tissue::CorticalBone,
"Metal" => Tissue::Metal,
_ => Tissue::SoftTissue,
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;

    #[test]
    fn test_dicom_header_validation() {
        let h = DicomHeader::new_ct("P001", 512, 512, 64);
        assert!(h.validate().is_ok());
        assert!(DicomHeader::new_mr("P002", 256, 256, 32).validate().is_ok());
    }

    #[test]
    fn test_dicom_header_invalid_slice_thickness() {
        let mut h = DicomHeader::new_ct("P001", 512, 512, 64);
        h.slice_thickness = -1.0;
        assert!(h.validate().is_err());
    }

    #[test]
    fn test_dicom_header_zero_dimension_rejected() {
        let h = DicomHeader::new_ct("P001", 0, 512, 64);
        assert!(h.validate().is_err());
    }

    #[test]
    fn test_window_leveling_apply() {
        let wl = WindowLeveling::ct_abdomen();
        assert!((wl.apply(-140.0) - 0.0).abs() < 1e-10);
        assert!((wl.apply(260.0) - 1.0).abs() < 1e-10);
        assert!((wl.apply(60.0) - 0.5).abs() < 1e-10);
        assert_eq!(wl.bounds(), (-140.0, 260.0));
    }

    #[test]
    fn test_window_leveling_rejects_non_positive_window() {
        assert!(WindowLeveling::new(40.0, 0.0).is_err());
        assert!(WindowLeveling::new(40.0, -5.0).is_err());
        assert!(WindowLeveling::new(40.0, 80.0).is_ok());
    }

    #[test]
    fn test_hounsfield_classify() {
        assert_eq!(HounsfieldUnits::classify(-1100.0), Tissue::Air);
        assert_eq!(HounsfieldUnits::classify(-700.0), Tissue::Lung);
        assert_eq!(HounsfieldUnits::classify(-200.0), Tissue::Fat);
        assert_eq!(HounsfieldUnits::classify(0.0), Tissue::SoftTissue);
        assert_eq!(HounsfieldUnits::classify(150.0), Tissue::Blood);
        assert_eq!(HounsfieldUnits::classify(500.0), Tissue::CancellousBone);
        assert_eq!(HounsfieldUnits::classify(1000.0), Tissue::CorticalBone);
        assert_eq!(HounsfieldUnits::classify(4000.0), Tissue::Metal);
    }

    #[test]
    fn test_tissue_range_consistent_with_classify() {
        let all = [
            Tissue::Air,
            Tissue::Lung,
            Tissue::Fat,
            Tissue::SoftTissue,
            Tissue::Blood,
            Tissue::CancellousBone,
            Tissue::CorticalBone,
            Tissue::Metal,
        ];
        for &t in &all {
            // Ranges are half-open `(lower, upper]`: the upper bound belongs to the class,
            // the lower bound to the class below.
            let (lo, hi) = HounsfieldUnits::tissue_range(t);
            assert_eq!(HounsfieldUnits::classify(hi), t, "upper bound of {:?}", t);
            if lo.is_finite() {
                assert_ne!(HounsfieldUnits::classify(lo), t, "lower bound of {:?}", t);
            }
        }
    }

    #[test]
    fn test_n4_bias_correction_smoke() {
        let mut data = Array3::<f64>::ones((4, 4, 4));
        for iz in 0..4_usize {
            for iy in 0..4_usize {
                for ix in 0..4_usize {
                    data[[iz, iy, ix]] = 100.0 + iz as f64 * 10.0 + iy as f64 * 5.0;
                }
            }
        }
        let vol = MedicalVolume::new(data, [1.0, 1.0, 1.0], [0.0, 0.0, 0.0])
            .expect("MedicalVolume::new should succeed with valid data");
        let corrector = N4BiasCorrection::new();
        let result = corrector.correct(&vol);
        assert!(result.is_ok(), "N4 correction failed: {:?}", result.err());
        let (corrected, bias) =
            result.expect("N4 correction result should be Ok after is_ok check");
        assert_eq!(corrected.shape(), (4, 4, 4));
        assert_eq!(bias.dim(), (4, 4, 4));
        assert!(corrected.data.iter().all(|v| v.is_finite()));
        assert!(bias.iter().all(|&b| b.is_finite() && b > 0.0));
    }

    #[test]
    fn test_volume_stats_ct() {
        let data = Array3::<f64>::from_elem((4, 4, 4), 50.0);
        let vol = MedicalVolume::new(data, [1.0, 1.0, 1.0], [0.0, 0.0, 0.0])
            .expect("MedicalVolume::new should succeed with uniform CT data");
        let stats = VolumeStats::compute_ct(&vol).expect("compute_ct should succeed on valid volume");
        assert!((stats.global_mean - 50.0).abs() < 1e-10);
        assert!(stats.global_std.abs() < 1e-10);
        assert_eq!(stats.total_voxels, 64);
        assert_eq!(stats.by_tissue.len(), 1);
        assert_eq!(stats.by_tissue["SoftTissue"].voxel_count, 64);
    }

    #[test]
    fn test_volume_stats_masked() {
        let data = Array3::<f64>::from_elem((4, 4, 4), 50.0);
        let vol = MedicalVolume::new(data, [1.0, 1.0, 1.0], [0.0, 0.0, 0.0])
            .expect("MedicalVolume::new should succeed with uniform data");
        let half = Array3::from_shape_fn((4, 4, 4), |(iz, _, _)| iz < 2);
        let stats = VolumeStats::compute_masked(&vol, &half).expect("non-empty mask should succeed");
        assert_eq!(stats.total_voxels, 32);
        assert_eq!(stats.by_tissue["Masked"].voxel_count, 32);

        let empty = Array3::from_elem((4, 4, 4), false);
        assert!(VolumeStats::compute_masked(&vol, &empty).is_err());
        let wrong_shape = Array3::from_elem((4, 4, 3), true);
        assert!(VolumeStats::compute_masked(&vol, &wrong_shape).is_err());
    }

    #[test]
    fn test_medical_volume_axial_slice() {
        let data = Array3::<f64>::zeros((5, 4, 3));
        let vol = MedicalVolume::new(data, [2.0, 1.5, 1.0], [0.0, 0.0, 0.0])
            .expect("MedicalVolume::new should succeed with zeros volume");
        let slice = vol.axial_slice(2).expect("axial_slice(2) should succeed for a 5-slice volume");
        assert_eq!(slice.len(), 12);
        assert!(vol.axial_slice(10).is_err());
    }

    #[test]
    fn test_medical_volume_axial_slice_row_major_order() {
        let data = Array3::from_shape_fn((2, 2, 3), |(z, y, x)| (z * 100 + y * 10 + x) as f64);
        let vol = MedicalVolume::new(data, [1.0, 1.0, 1.0], [0.0, 0.0, 0.0])
            .expect("MedicalVolume::new should succeed");
        let slice = vol.axial_slice(1).expect("slice 1 exists");
        assert_eq!(slice, vec![100.0, 101.0, 102.0, 110.0, 111.0, 112.0]);
    }

    #[test]
    fn test_medical_volume_geometry() {
        let data = Array3::<f64>::zeros((2, 3, 4));
        let vol = MedicalVolume::new(data, [2.0, 1.5, 1.0], [10.0, 20.0, 30.0])
            .expect("MedicalVolume::new should succeed");
        assert_eq!(vol.num_voxels(), 24);
        assert_eq!(vol.voxel_to_physical(1, 2, 3), [12.0, 23.0, 33.0]);
        assert!((vol.physical_volume_mm3() - 72.0).abs() < 1e-12);
        let bad = MedicalVolume::new(Array3::<f64>::zeros((1, 1, 1)), [1.0, 0.0, 1.0], [0.0; 3]);
        assert!(bad.is_err());
    }

    #[test]
    fn test_polynomial_basis() {
        let basis = polynomial_basis_3d(2);
        assert_eq!(basis.len(), 10);
        assert_eq!(
            polynomial_basis_3d(1),
            vec![(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)]
        );
    }
}