use std::collections::HashMap;
use scirs2_core::ndarray::Array3;
use crate::error::{NdimageError, NdimageResult};
/// Validates a set of atlas label volumes and returns their common `(nz, ny, nx)` shape.
///
/// # Errors
/// * `NdimageError::InvalidInput` when `labels` is empty.
/// * `NdimageError::DimensionError` naming the first atlas whose shape differs
///   from that of `labels[0]`.
fn check_shapes(labels: &[Array3<u32>]) -> NdimageResult<(usize, usize, usize)> {
    let reference = match labels.first() {
        Some(first) => first.shape(),
        None => {
            return Err(NdimageError::InvalidInput(
                "Atlas segmentation: must provide at least one atlas".to_string(),
            ))
        }
    };
    let mismatch = labels
        .iter()
        .enumerate()
        .skip(1)
        .find(|(_, lab)| lab.shape() != reference);
    match mismatch {
        Some((idx, lab)) => Err(NdimageError::DimensionError(format!(
            "Atlas segmentation: atlas {} has shape {:?}, expected {:?}",
            idx,
            lab.shape(),
            reference
        ))),
        None => Ok((reference[0], reference[1], reference[2])),
    }
}
/// Per-voxel plurality-vote label fusion across atlases.
pub struct MajorityVoting;

impl MajorityVoting {
    /// Fuses atlas label volumes by taking, at every voxel, the label that occurs
    /// in the most atlases.
    ///
    /// Ties are broken deterministically in favour of the smallest label value.
    ///
    /// # Errors
    /// Returns `NdimageError::InvalidInput` if `labels` is empty and
    /// `NdimageError::DimensionError` if the atlases do not all share one shape.
    pub fn fuse(labels: &[Array3<u32>]) -> NdimageResult<Array3<u32>> {
        let shape = check_shapes(labels)?;
        // One tally map, cleared per voxel, instead of a fresh allocation per voxel.
        let mut counts: HashMap<u32, usize> = HashMap::with_capacity(labels.len());
        let fused = Array3::from_shape_fn(shape, |(iz, iy, ix)| {
            counts.clear();
            for atlas in labels {
                *counts.entry(atlas[[iz, iy, ix]]).or_insert(0) += 1;
            }
            counts
                .iter()
                // Highest count wins; on equal counts the smaller label compares as
                // "greater", so it is selected. This makes the result independent of
                // HashMap iteration order.
                .max_by(|a, b| a.1.cmp(b.1).then_with(|| b.0.cmp(a.0)))
                .map(|(&lv, _)| lv)
                .unwrap_or(0)
        });
        Ok(fused)
    }

    /// Fraction of atlases that agree with the fused (majority) label at every voxel.
    ///
    /// Values lie in `(0, 1]`; 1.0 means all atlases agree.
    ///
    /// # Errors
    /// Same as [`MajorityVoting::fuse`].
    pub fn confidence(labels: &[Array3<u32>]) -> NdimageResult<Array3<f64>> {
        // `fuse` already validates the inputs, so its output shape is authoritative.
        let fused = Self::fuse(labels)?;
        let n_atlases = labels.len() as f64;
        Ok(Array3::from_shape_fn(fused.dim(), |(iz, iy, ix)| {
            let winner = fused[[iz, iy, ix]];
            let agree = labels
                .iter()
                .filter(|l| l[[iz, iy, ix]] == winner)
                .count();
            agree as f64 / n_atlases
        }))
    }
}
/// Tuning parameters for [`STAPLE`].
#[derive(Debug, Clone)]
pub struct StapleConfig {
/// Upper bound on the number of EM iterations `STAPLE::estimate` runs.
pub max_iterations: usize,
/// Convergence tolerance. Iteration stops once both the largest change in any
/// rater's sensitivity/specificity and the largest change in any voxel's
/// foreground probability fall below this value.
pub convergence_threshold: f64,
/// Initial per-rater sensitivity (true-positive rate) estimate, before any EM update.
pub init_sensitivity: f64,
/// Initial per-rater specificity (true-negative rate) estimate, before any EM update.
pub init_specificity: f64,
}
// Defaults: at most 20 iterations, tolerance 1e-5, and 0.99 for both initial
// sensitivity and specificity.
impl Default for StapleConfig {
fn default() -> Self {
Self {
max_iterations: 20,
convergence_threshold: 1e-5,
init_sensitivity: 0.99,
init_specificity: 0.99,
}
}
}
/// Estimated performance of one rater (atlas) from STAPLE.
#[derive(Debug, Clone)]
pub struct RaterPerformance {
/// Estimated probability that the rater marks a true-foreground voxel as foreground.
pub sensitivity: f64,
/// Estimated probability that the rater marks a true-background voxel as background.
pub specificity: f64,
}
/// Output of `STAPLE::estimate`.
#[derive(Debug, Clone)]
pub struct StapleResult {
/// Estimated probability, in `[0, 1]`, that each voxel is foreground.
pub probability: Array3<f64>,
/// Hard segmentation: 1 where `probability > 0.5`, otherwise 0.
pub label: Array3<u32>,
/// One entry per input atlas, in input order.
pub performance: Vec<RaterPerformance>,
/// Number of EM iterations actually performed (0 if `max_iterations` is 0).
pub iterations: usize,
/// `true` if iteration stopped because the changes fell below the convergence
/// threshold; `false` if `max_iterations` was reached first.
pub converged: bool,
}
/// STAPLE (Simultaneous Truth And Performance Level Estimation) for binary
/// segmentations.
///
/// Every input atlas is binarised (`label > 0` is foreground). An
/// expectation-maximisation loop then jointly estimates the probability that
/// each voxel is foreground, together with the sensitivity and specificity of
/// each atlas ("rater").
pub struct STAPLE {
    config: StapleConfig,
}

impl Default for STAPLE {
    fn default() -> Self {
        Self::new()
    }
}

impl STAPLE {
    /// Creates an estimator using [`StapleConfig::default`].
    pub fn new() -> Self {
        Self {
            config: StapleConfig::default(),
        }
    }

    /// Creates an estimator using the given configuration.
    pub fn with_config(config: StapleConfig) -> Self {
        Self { config }
    }

    /// Runs STAPLE on `labels` (one volume per rater).
    ///
    /// The EM loop is seeded by an E-step that uses the configured
    /// `init_sensitivity` / `init_specificity`, each clamped to
    /// `[1e-6, 1 - 1e-6]`. Each iteration then performs:
    /// 1. an M-step, which re-estimates rater performance from the current posterior;
    /// 2. an E-step, which recomputes the posterior from the new performance.
    ///
    /// Iteration stops when both the largest parameter change and the largest
    /// posterior change fall below `convergence_threshold`, or after
    /// `max_iterations` iterations. The foreground prior is fixed at 0.5.
    ///
    /// # Errors
    /// Returns `NdimageError::InvalidInput` for an empty `labels` slice and
    /// `NdimageError::DimensionError` if the volumes do not share one shape.
    pub fn estimate(&self, labels: &[Array3<u32>]) -> NdimageResult<StapleResult> {
        let (nz, ny, nx) = check_shapes(labels)?;
        let n_raters = labels.len();
        // Binary decision of every rater at every voxel. Volumes are flattened in
        // logical (z, y, x) order, so the flat index is iz * ny * nx + iy * nx + ix.
        let decisions: Vec<Vec<u8>> = labels
            .iter()
            .map(|l| l.iter().map(|&v| u8::from(v > 0)).collect())
            .collect();
        // Fixed, uninformative foreground prior.
        let prior_fg = 0.5_f64;

        // Seeding the posterior from the configured rater performance is what gives
        // `init_sensitivity` / `init_specificity` their effect. They are clamped like
        // every later estimate so that the logarithms in the E-step stay finite.
        let mut p = vec![Self::clamp_probability(self.config.init_sensitivity); n_raters];
        let mut q = vec![Self::clamp_probability(self.config.init_specificity); n_raters];
        let mut w = Self::e_step(&decisions, &p, &q, prior_fg);

        let mut converged = false;
        let mut n_iter = 0;
        while n_iter < self.config.max_iterations {
            n_iter += 1;
            let (new_p, new_q) = Self::m_step(&decisions, &w);
            let new_w = Self::e_step(&decisions, &new_p, &new_q, prior_fg);
            let max_change = Self::max_abs_diff(&new_w, &w);
            let param_change =
                Self::max_abs_diff(&new_p, &p).max(Self::max_abs_diff(&new_q, &q));
            p = new_p;
            q = new_q;
            w = new_w;
            if param_change < self.config.convergence_threshold
                && max_change < self.config.convergence_threshold
            {
                converged = true;
                break;
            }
        }

        let plane = ny * nx;
        let probability =
            Array3::from_shape_fn((nz, ny, nx), |(iz, iy, ix)| w[iz * plane + iy * nx + ix]);
        let label = probability.mapv(|wi| u32::from(wi > 0.5));
        let performance: Vec<RaterPerformance> = p
            .iter()
            .zip(&q)
            .map(|(&sensitivity, &specificity)| RaterPerformance {
                sensitivity,
                specificity,
            })
            .collect();

        Ok(StapleResult {
            probability,
            label,
            performance,
            iterations: n_iter,
            converged,
        })
    }

    /// Clamps a probability into `[1e-6, 1 - 1e-6]`, so that both its logarithm
    /// and the logarithm of its complement are finite.
    fn clamp_probability(v: f64) -> f64 {
        v.clamp(1e-6, 1.0 - 1e-6)
    }

    /// Largest absolute element-wise difference between two slices (0 if empty).
    fn max_abs_diff(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).abs())
            .fold(0.0_f64, f64::max)
    }

    /// M-step: estimates every rater's (sensitivity, specificity) from the current
    /// per-voxel foreground probabilities `w`.
    fn m_step(decisions: &[Vec<u8>], w: &[f64]) -> (Vec<f64>, Vec<f64>) {
        let sum_w: f64 = w.iter().sum();
        let sum_w0: f64 = w.iter().map(|&wi| 1.0 - wi).sum();
        decisions
            .iter()
            .map(|rater| {
                let tp: f64 = rater
                    .iter()
                    .zip(w)
                    .map(|(&dij, &wi)| f64::from(dij) * wi)
                    .sum();
                let tn: f64 = rater
                    .iter()
                    .zip(w)
                    .map(|(&dij, &wi)| (1.0 - f64::from(dij)) * (1.0 - wi))
                    .sum();
                // The 1e-10 terms keep the ratios defined when one class is absent.
                let sensitivity = Self::clamp_probability((tp + 1e-10) / (sum_w + 1e-10));
                let specificity = Self::clamp_probability((tn + 1e-10) / (sum_w0 + 1e-10));
                (sensitivity, specificity)
            })
            .unzip()
    }

    /// E-step: computes the posterior foreground probability of every voxel.
    ///
    /// Inputs are the rater sensitivities `p`, specificities `q` and the
    /// foreground prior. Works in log space and subtracts the larger
    /// log-likelihood before exponentiating, so many raters cannot underflow
    /// both terms to zero.
    fn e_step(decisions: &[Vec<u8>], p: &[f64], q: &[f64], prior_fg: f64) -> Vec<f64> {
        let n_voxels = decisions.first().map_or(0, Vec::len);
        // Per-rater log terms do not depend on the voxel; compute them once.
        let ln_p: Vec<f64> = p.iter().map(|v| v.ln()).collect();
        let ln_not_p: Vec<f64> = p.iter().map(|v| (1.0 - v).ln()).collect();
        let ln_q: Vec<f64> = q.iter().map(|v| v.ln()).collect();
        let ln_not_q: Vec<f64> = q.iter().map(|v| (1.0 - v).ln()).collect();
        let ln_prior_fg = prior_fg.ln();
        let ln_prior_bg = (1.0 - prior_fg).ln();
        (0..n_voxels)
            .map(|i| {
                let mut ll1 = ln_prior_fg;
                let mut ll0 = ln_prior_bg;
                for (j, rater) in decisions.iter().enumerate() {
                    if rater[i] == 1 {
                        ll1 += ln_p[j];
                        ll0 += ln_not_q[j];
                    } else {
                        ll1 += ln_not_p[j];
                        ll0 += ln_q[j];
                    }
                }
                let max_ll = ll1.max(ll0);
                let p1 = (ll1 - max_ll).exp();
                let p0 = (ll0 - max_ll).exp();
                p1 / (p1 + p0 + 1e-10)
            })
            .collect()
    }
}
/// Tuning parameters for [`JointLabelFusion`].
///
/// An atlas patch is weighted by `exp(-(SSD / (n * beta)) / alpha)`, where:
/// * `SSD` is the sum of squared intensity differences to the target patch;
/// * `n` is the number of voxels in the patch.
///
/// Small `1e-10` terms guard both divisions. `alpha` and `beta` therefore act
/// together as a single bandwidth `n * alpha * beta`.
#[derive(Debug, Clone)]
pub struct JlfConfig {
/// Half-width of the cubic patch compared around each voxel. The patch holds
/// `(2 * patch_radius + 1)^3` samples; coordinates outside the volume are
/// clamped to the nearest edge voxel.
pub patch_radius: usize,
/// Divides the normalised SSD inside the exponent; smaller values make the
/// weighting more selective.
pub alpha: f64,
/// Multiplies the patch size in the SSD normaliser; larger values make the
/// weighting more tolerant of intensity differences.
pub beta: f64,
}
// Defaults: 5x5x5 patches (radius 2), alpha = 0.1, beta = 2.0.
impl Default for JlfConfig {
fn default() -> Self {
Self {
patch_radius: 2,
alpha: 0.1,
beta: 2.0,
}
}
}
/// Output of [`JointLabelFusion::fuse`].
#[derive(Debug, Clone)]
pub struct JlfResult {
    /// Fused label of every voxel.
    pub label: Array3<u32>,
    /// Raw (pre-normalisation) sum of the atlas patch-similarity weights at every
    /// voxel, as a per-voxel indication of how well the atlases match the target.
    ///
    /// With positive `alpha`/`beta`, each weight lies in `[0, 1]`. The sum therefore
    /// ranges from 0 (no atlas patch resembles the target) up to the number of
    /// atlases (every patch matches exactly). It may be NaN when patches contain
    /// non-finite intensities.
    pub weight_sum: Array3<f64>,
}

/// Patch-similarity-weighted label fusion.
///
/// At every voxel, each atlas votes for its own label. The vote's weight is
/// derived from the intensity similarity between the target patch and that
/// atlas's patch around the voxel. Weights are computed independently per atlas;
/// the published joint-label-fusion method additionally models pairwise atlas
/// error correlations, which this implementation does not.
pub struct JointLabelFusion {
    config: JlfConfig,
}

impl Default for JointLabelFusion {
    fn default() -> Self {
        Self::new()
    }
}

impl JointLabelFusion {
    /// Creates a fuser using [`JlfConfig::default`].
    pub fn new() -> Self {
        Self {
            config: JlfConfig::default(),
        }
    }

    /// Creates a fuser using the given configuration.
    pub fn with_config(config: JlfConfig) -> Self {
        Self { config }
    }

    /// Fuses `atlas_labels` onto `target` using the matching `atlas_images`.
    ///
    /// At each voxel, the per-atlas patch weights are normalised to sum to 1. If
    /// their raw sum is `<= 1e-12` or NaN, they fall back to uniform weights. The
    /// label with the largest total weight wins; ties go to the smallest label value.
    ///
    /// # Errors
    /// * `NdimageError::InvalidInput` if `atlas_images` and `atlas_labels` differ in
    ///   length, or if no atlases are given.
    /// * `NdimageError::DimensionError` if any atlas image or label volume has a
    ///   shape different from `target`.
    pub fn fuse(
        &self,
        target: &Array3<f64>,
        atlas_images: &[Array3<f64>],
        atlas_labels: &[Array3<u32>],
    ) -> NdimageResult<JlfResult> {
        if atlas_images.len() != atlas_labels.len() {
            return Err(NdimageError::InvalidInput(
                "JointLabelFusion: atlas_images and atlas_labels must have equal length".to_string(),
            ));
        }
        let n_atlases = atlas_images.len();
        if n_atlases == 0 {
            return Err(NdimageError::InvalidInput(
                "JointLabelFusion: must provide at least one atlas".to_string(),
            ));
        }
        let ts = target.shape();
        for (i, ai) in atlas_images.iter().enumerate() {
            if ai.shape() != ts {
                return Err(NdimageError::DimensionError(format!(
                    "JointLabelFusion: atlas_images[{}] shape {:?} ≠ target shape {:?}",
                    i,
                    ai.shape(),
                    ts
                )));
            }
        }
        for (i, al) in atlas_labels.iter().enumerate() {
            if al.shape() != ts {
                return Err(NdimageError::DimensionError(format!(
                    "JointLabelFusion: atlas_labels[{}] shape {:?} ≠ target shape {:?}",
                    i,
                    al.shape(),
                    ts
                )));
            }
        }

        let (nz, ny, nx) = (ts[0], ts[1], ts[2]);
        let pr = self.config.patch_radius as isize;
        let uniform = 1.0 / n_atlases as f64;
        let mut label = Array3::<u32>::zeros((nz, ny, nx));
        let mut weight_sum = Array3::<f64>::zeros((nz, ny, nx));
        // Scratch buffers reused across voxels. Votes are tallied per voxel instead
        // of keeping one full volume per label, which cuts memory from
        // O(labels * voxels) to O(voxels).
        let mut weights: Vec<f64> = Vec::with_capacity(n_atlases);
        let mut votes: HashMap<u32, f64> = HashMap::with_capacity(n_atlases);

        for iz in 0..nz {
            for iy in 0..ny {
                for ix in 0..nx {
                    let (z, y, x) = (iz as isize, iy as isize, ix as isize);
                    let t_patch = extract_patch_3d(target, z, y, x, pr);

                    weights.clear();
                    weights.extend(atlas_images.iter().map(|img| {
                        self.patch_weight(&t_patch, &extract_patch_3d(img, z, y, x, pr))
                    }));

                    let w_sum: f64 = weights.iter().sum();
                    // Record the raw total before normalisation; after normalising it
                    // would always be 1 and carry no information.
                    weight_sum[[iz, iy, ix]] = w_sum;
                    if w_sum > 1e-12 {
                        weights.iter_mut().for_each(|w| *w /= w_sum);
                    } else {
                        // No usable similarity signal (all ~0 or NaN): unweighted vote.
                        weights.iter_mut().for_each(|w| *w = uniform);
                    }

                    votes.clear();
                    for (atlas, &wn) in atlas_labels.iter().zip(&weights) {
                        *votes.entry(atlas[[iz, iy, ix]]).or_insert(0.0) += wn;
                    }
                    label[[iz, iy, ix]] = votes
                        .iter()
                        .max_by(|a, b| {
                            a.1.partial_cmp(b.1)
                                .unwrap_or(std::cmp::Ordering::Equal)
                                .then_with(|| b.0.cmp(a.0))
                        })
                        .map(|(&lv, _)| lv)
                        .unwrap_or(0);
                }
            }
        }
        Ok(JlfResult { label, weight_sum })
    }

    /// Similarity weight of an atlas patch relative to the target patch.
    ///
    /// Computes `exp(-(SSD / (n * beta)) / alpha)`, where `SSD` is the sum of
    /// squared differences over the paired samples and `n` the shorter patch
    /// length. Returns 1.0 for an empty target patch.
    fn patch_weight(&self, target_patch: &[f64], atlas_patch: &[f64]) -> f64 {
        if target_patch.is_empty() {
            return 1.0;
        }
        let n = target_patch.len().min(atlas_patch.len()) as f64;
        let ssd: f64 = target_patch
            .iter()
            .zip(atlas_patch.iter())
            .map(|(t, a)| (t - a).powi(2))
            .sum();
        let normalised_ssd = ssd / (n * self.config.beta + 1e-10);
        (-normalised_ssd / (self.config.alpha + 1e-10)).exp()
    }
}
/// Extracts the cubic neighbourhood of `vol` centred at `(iz, iy, ix)` with
/// half-width `pr`. The samples are flattened with z outermost, then y, then x.
///
/// Coordinates outside the volume are clamped to the nearest edge voxel (edge
/// replication), so the patch always has `(2 * pr + 1)^3` samples for `pr >= 0`.
/// Returns an empty vector when `vol` has a zero-length axis: there is no voxel
/// to replicate, and clamping to `0..=len-1` would otherwise panic.
fn extract_patch_3d(
    vol: &Array3<f64>,
    iz: isize,
    iy: isize,
    ix: isize,
    pr: isize,
) -> Vec<f64> {
    let (nz, ny, nx) = vol.dim();
    if nz == 0 || ny == 0 || nx == 0 {
        return Vec::new();
    }
    let (nz, ny, nx) = (nz as isize, ny as isize, nx as isize);
    // Guard the capacity computation against a negative radius (the ranges below
    // are then simply empty).
    let side = (2 * pr.max(0) + 1) as usize;
    let mut patch = Vec::with_capacity(side.pow(3));
    for dz in -pr..=pr {
        let z = (iz + dz).clamp(0, nz - 1) as usize;
        for dy in -pr..=pr {
            let y = (iy + dy).clamp(0, ny - 1) as usize;
            for dx in -pr..=pr {
                let x = (ix + dx).clamp(0, nx - 1) as usize;
                patch.push(vol[[z, y, x]]);
            }
        }
    }
    patch
}
/// Label-fusion strategy used by `AtlasSegmentation::segment`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionMethod {
/// Per-voxel plurality vote (`MajorityVoting`); needs only the atlas labels.
MajorityVoting,
/// `STAPLE` expectation-maximisation on binarised labels (`label > 0`); needs
/// only the atlas labels.
Staple,
/// Patch-similarity-weighted voting (`JointLabelFusion`); additionally needs the
/// target image and the atlas intensity images.
JointLabelFusion,
}
/// Configuration for `AtlasSegmentation`.
#[derive(Debug, Clone)]
pub struct AtlasConfig {
/// Which fusion strategy `segment` dispatches to.
pub fusion_method: FusionMethod,
/// Parameters used when `fusion_method` is `FusionMethod::Staple`.
pub staple_config: StapleConfig,
/// Parameters used when `fusion_method` is `FusionMethod::JointLabelFusion`.
pub jlf_config: JlfConfig,
}
// Defaults to majority voting, with default STAPLE and JLF parameters.
impl Default for AtlasConfig {
fn default() -> Self {
Self {
fusion_method: FusionMethod::MajorityVoting,
staple_config: StapleConfig::default(),
jlf_config: JlfConfig::default(),
}
}
}
/// Output of `AtlasSegmentation::segment`.
#[derive(Debug, Clone)]
pub struct AtlasSegmentationResult {
/// Fused label volume.
pub label: Array3<u32>,
/// Number of atlas label volumes that were supplied.
pub n_atlases: usize,
/// The fusion method that produced `label`.
pub fusion_method: FusionMethod,
/// Full STAPLE output (posterior, rater performance, convergence info); `Some`
/// only when `fusion_method` is `FusionMethod::Staple`.
pub staple_result: Option<StapleResult>,
}
/// Multi-atlas segmentation front end.
///
/// Dispatches to the fusion method selected in its [`AtlasConfig`].
pub struct AtlasSegmentation {
    config: AtlasConfig,
}

impl Default for AtlasSegmentation {
    fn default() -> Self {
        Self::new()
    }
}

impl AtlasSegmentation {
    /// Creates a segmenter using [`AtlasConfig::default`] (majority voting).
    pub fn new() -> Self {
        Self {
            config: AtlasConfig::default(),
        }
    }

    /// Creates a segmenter using the given configuration.
    pub fn with_config(config: AtlasConfig) -> Self {
        Self { config }
    }

    /// Fuses `atlas_labels` into a single label volume using the configured method.
    ///
    /// `target_image` and `atlas_images` are required only for
    /// [`FusionMethod::JointLabelFusion`]; the other methods ignore them.
    ///
    /// # Errors
    /// * Propagates validation errors from the selected fusion method (empty
    ///   input, mismatched shapes or lengths).
    /// * Returns `NdimageError::InvalidInput` when JLF is selected but
    ///   `target_image` or `atlas_images` is `None`.
    pub fn segment(
        &self,
        atlas_labels: &[Array3<u32>],
        target_image: Option<&Array3<f64>>,
        atlas_images: Option<&[Array3<f64>]>,
    ) -> NdimageResult<AtlasSegmentationResult> {
        let n_atlases = atlas_labels.len();
        let method = self.config.fusion_method;
        let (label, staple_result) = match method {
            FusionMethod::MajorityVoting => (MajorityVoting::fuse(atlas_labels)?, None),
            FusionMethod::Staple => {
                let sr = STAPLE::with_config(self.config.staple_config.clone())
                    .estimate(atlas_labels)?;
                // The hard label is returned both at top level and inside the
                // full STAPLE result, hence the copy.
                (sr.label.clone(), Some(sr))
            }
            FusionMethod::JointLabelFusion => {
                let target = target_image.ok_or_else(|| {
                    NdimageError::InvalidInput(
                        "AtlasSegmentation: JLF requires target_image".to_string(),
                    )
                })?;
                let imgs = atlas_images.ok_or_else(|| {
                    NdimageError::InvalidInput(
                        "AtlasSegmentation: JLF requires atlas_images".to_string(),
                    )
                })?;
                let jr = JointLabelFusion::with_config(self.config.jlf_config.clone())
                    .fuse(target, imgs, atlas_labels)?;
                (jr.label, None)
            }
        };
        Ok(AtlasSegmentationResult {
            label,
            n_atlases,
            fusion_method: method,
            staple_result,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array3;

    /// Builds an `nz x ny x nx` volume that holds `label` inside a centred ball
    /// (radius one third of the smallest dimension) and 0 elsewhere.
    fn sphere_labels(nz: usize, ny: usize, nx: usize, label: u32) -> Array3<u32> {
        let mut a = Array3::<u32>::zeros((nz, ny, nx));
        let cz = nz as f64 / 2.0;
        let cy = ny as f64 / 2.0;
        let cx = nx as f64 / 2.0;
        let r2 = ((nz.min(ny).min(nx)) as f64 / 3.0).powi(2);
        for iz in 0..nz {
            for iy in 0..ny {
                for ix in 0..nx {
                    let d2 = (iz as f64 - cz).powi(2)
                        + (iy as f64 - cy).powi(2)
                        + (ix as f64 - cx).powi(2);
                    if d2 < r2 {
                        a[[iz, iy, ix]] = label;
                    }
                }
            }
        }
        a
    }

    #[test]
    fn test_majority_voting_identical_atlases() {
        let a = sphere_labels(8, 8, 8, 1);
        let labels = vec![a.clone(), a.clone(), a.clone()];
        let fused = MajorityVoting::fuse(&labels)
            .expect("MajorityVoting::fuse should succeed with identical atlases");
        assert_eq!(fused, a);
    }

    #[test]
    fn test_majority_voting_confidence_perfect() {
        let a = sphere_labels(6, 6, 6, 1);
        let labels = vec![a.clone(), a.clone()];
        let conf = MajorityVoting::confidence(&labels)
            .expect("MajorityVoting::confidence should succeed with identical atlases");
        for v in conf.iter() {
            assert!((*v - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_majority_voting_tie_breaks() {
        let a = sphere_labels(4, 4, 4, 1);
        let b = sphere_labels(4, 4, 4, 2);
        let labels = vec![a.clone(), b];
        let fused = MajorityVoting::fuse(&labels)
            .expect("MajorityVoting::fuse should succeed with two atlases");
        // Inside the ball the vote is 1 vs 2 (one each); ties go to the smallest label.
        assert_eq!(fused, a);
    }

    #[test]
    fn test_majority_voting_rejects_empty_input() {
        let labels: Vec<Array3<u32>> = Vec::new();
        assert!(MajorityVoting::fuse(&labels).is_err());
    }

    #[test]
    fn test_majority_voting_rejects_shape_mismatch() {
        let labels = vec![
            Array3::<u32>::zeros((2, 2, 2)),
            Array3::<u32>::zeros((2, 2, 3)),
        ];
        assert!(MajorityVoting::fuse(&labels).is_err());
    }

    #[test]
    fn test_staple_smoke() {
        let a = sphere_labels(4, 4, 4, 1);
        let b = sphere_labels(4, 4, 4, 1);
        let labels = vec![a, b];
        let staple = STAPLE::new();
        let result = staple
            .estimate(&labels)
            .expect("STAPLE::estimate should succeed with valid atlases");
        assert_eq!(result.performance.len(), 2);
        for perf in &result.performance {
            assert!(
                perf.sensitivity > 0.5,
                "Expected high sensitivity, got {}",
                perf.sensitivity
            );
        }
    }

    #[test]
    fn test_staple_identical_atlases_recover_label() {
        let a = sphere_labels(4, 4, 4, 1);
        let labels = vec![a.clone(), a.clone()];
        let result = STAPLE::new()
            .estimate(&labels)
            .expect("STAPLE::estimate should succeed with identical atlases");
        assert_eq!(result.label, a);
    }

    #[test]
    fn test_staple_single_atlas() {
        let a = sphere_labels(4, 4, 4, 1);
        let labels = vec![a];
        let result = STAPLE::new()
            .estimate(&labels)
            .expect("STAPLE::estimate should succeed with single atlas");
        assert_eq!(result.performance.len(), 1);
    }

    #[test]
    fn test_jlf_smoke() {
        let n = 6;
        let target = Array3::<f64>::from_elem((n, n, n), 100.0);
        let atlas_img = Array3::<f64>::from_elem((n, n, n), 100.0);
        let atlas_label = sphere_labels(n, n, n, 1);
        let jlf = JointLabelFusion::new();
        let result = jlf
            .fuse(&target, &[atlas_img], &[atlas_label.clone()])
            .expect("JLF::fuse should succeed with single identical atlas");
        assert_eq!(result.label, atlas_label);
    }

    #[test]
    fn test_jlf_rejects_mismatched_lengths() {
        let img = Array3::<f64>::zeros((2, 2, 2));
        let lab = Array3::<u32>::zeros((2, 2, 2));
        let result = JointLabelFusion::new().fuse(&img, &[img.clone()], &[lab.clone(), lab]);
        assert!(result.is_err());
    }

    #[test]
    fn test_atlas_segmentation_majority_voting() {
        let a = sphere_labels(6, 6, 6, 1);
        let labels = vec![a.clone(), a.clone()];
        let seg = AtlasSegmentation::new();
        let result = seg
            .segment(&labels, None, None)
            .expect("AtlasSegmentation::segment should succeed with valid atlases");
        assert_eq!(result.fusion_method, FusionMethod::MajorityVoting);
        assert_eq!(result.n_atlases, 2);
    }

    #[test]
    fn test_atlas_segmentation_staple() {
        let a = sphere_labels(4, 4, 4, 1);
        let labels = vec![a.clone(), a.clone()];
        let config = AtlasConfig {
            fusion_method: FusionMethod::Staple,
            ..Default::default()
        };
        let seg = AtlasSegmentation::with_config(config);
        let result = seg
            .segment(&labels, None, None)
            .expect("AtlasSegmentation STAPLE should succeed with valid atlases");
        assert_eq!(result.fusion_method, FusionMethod::Staple);
        assert!(result.staple_result.is_some());
    }

    #[test]
    fn test_atlas_segmentation_jlf_requires_images() {
        let labels = vec![sphere_labels(4, 4, 4, 1)];
        let config = AtlasConfig {
            fusion_method: FusionMethod::JointLabelFusion,
            ..Default::default()
        };
        let seg = AtlasSegmentation::with_config(config);
        assert!(seg.segment(&labels, None, None).is_err());
    }
}