use crate::error::{Result, VisionError};
use scirs2_core::ndarray::{Array2, Array3};
/// Selects the feature-extractor backbone of an FCN model.
///
/// Only the variant identities are defined here; nothing in this file
/// dispatches on them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FCNBackbone {
    /// "VGG-lite" backbone variant.
    VGGLite,
    /// ResNet-18 backbone variant (used by `FCNConfig::default`).
    ResNet18,
    /// MobileNetV2 backbone variant.
    MobileNetV2,
}

/// Configuration of a fully convolutional segmentation network.
#[derive(Clone, Debug)]
pub struct FCNConfig {
    /// Number of output segmentation classes.
    pub num_classes: usize,
    /// Number of channels of the input image.
    pub input_channels: usize,
    /// Backbone used as the feature extractor.
    pub backbone: FCNBackbone,
    /// Network output stride (e.g. 32).
    /// NOTE(review): nothing in this file reads this value — confirm semantics.
    pub stride: usize,
    /// Whether a per-pixel softmax should be applied to the logits.
    /// NOTE(review): nothing in this file reads this flag — confirm semantics.
    pub apply_softmax: bool,
}

impl Default for FCNConfig {
    /// 21 classes, 3 input channels, ResNet-18 backbone, stride 32, softmax on.
    fn default() -> Self {
        let backbone = FCNBackbone::ResNet18;
        FCNConfig {
            backbone,
            num_classes: 21,
            input_channels: 3,
            stride: 32,
            apply_softmax: true,
        }
    }
}
/// Output of an FCN forward pass: raw per-pixel scores plus their argmax.
#[derive(Clone, Debug)]
pub struct FCNOutput {
    /// Per-pixel class scores indexed `[[y, x, class]]`.
    pub logits: Array3<f32>,
    /// Index of the highest-scoring class at every `[[y, x]]`.
    pub class_map: Array2<usize>,
    /// Number of class channels in `logits`.
    pub num_classes: usize,
}

impl FCNOutput {
    /// Wraps `logits` and derives the per-pixel argmax class map.
    ///
    /// A later channel replaces the current best only when it is strictly
    /// greater, so ties resolve to the lowest class index.
    ///
    /// # Errors
    ///
    /// Returns `VisionError::InvalidParameter` when `logits` has zero class
    /// channels.
    pub fn from_logits(logits: Array3<f32>) -> Result<Self> {
        let (height, width, classes) = logits.dim();
        if classes == 0 {
            return Err(VisionError::InvalidParameter(
                "logits must have at least one class channel".into(),
            ));
        }
        let class_map = Array2::from_shape_fn((height, width), |(row, col)| {
            (1..classes)
                .fold((0usize, logits[[row, col, 0]]), |(best_idx, best_val), k| {
                    let candidate = logits[[row, col, k]];
                    if candidate > best_val {
                        (k, candidate)
                    } else {
                        (best_idx, best_val)
                    }
                })
                .0
        });
        Ok(FCNOutput {
            num_classes: classes,
            class_map,
            logits,
        })
    }

    /// Replaces `logits` in place with per-pixel softmax probabilities.
    ///
    /// The per-pixel maximum is subtracted before exponentiating to avoid
    /// overflow. `class_map` is not recomputed. A pixel whose exponential sum
    /// is not positive (e.g. NaN from all-`-inf` logits) is left unnormalised.
    pub fn with_softmax(mut self) -> Self {
        let (height, width, classes) = self.logits.dim();
        for row in 0..height {
            for col in 0..width {
                let mut peak = f32::NEG_INFINITY;
                for k in 0..classes {
                    peak = peak.max(self.logits[[row, col, k]]);
                }
                let mut total = 0.0f32;
                for k in 0..classes {
                    let shifted = (self.logits[[row, col, k]] - peak).exp();
                    self.logits[[row, col, k]] = shifted;
                    total += shifted;
                }
                if total > 0.0 {
                    for k in 0..classes {
                        self.logits[[row, col, k]] /= total;
                    }
                }
            }
        }
        self
    }
}
/// Resizes a map indexed `[[y, x, channel]]` to `(out_h, out_w, channels)`
/// with bilinear interpolation, using half-pixel-centre sampling (the
/// `align_corners = false` convention).
///
/// Source coordinates that fall before the first pixel are clamped to it, so
/// border pixels are replicated instead of being blended with the next
/// interior row or column. Channels are interpolated independently.
///
/// # Errors
///
/// Returns `VisionError::InvalidParameter` if an output dimension or an input
/// spatial dimension is zero.
pub fn bilinear_upsample_mask(
    mask: &Array3<f32>,
    out_h: usize,
    out_w: usize,
) -> Result<Array3<f32>> {
    let (in_h, in_w, c) = mask.dim();
    if out_h == 0 || out_w == 0 {
        return Err(VisionError::InvalidParameter(
            "bilinear_upsample_mask: output dimensions must be > 0".into(),
        ));
    }
    if in_h == 0 || in_w == 0 {
        return Err(VisionError::InvalidParameter(
            "bilinear_upsample_mask: input dimensions must be > 0".into(),
        ));
    }

    // Column taps are the same for every output row, so compute them once.
    let x_taps: Vec<(usize, usize, f32)> = (0..out_w)
        .map(|ox| bilinear_source_taps(ox, in_w, out_w))
        .collect();

    let mut out = Array3::<f32>::zeros((out_h, out_w, c));
    for oy in 0..out_h {
        let (y0, y1, dy) = bilinear_source_taps(oy, in_h, out_h);
        for (ox, &(x0, x1, dx)) in x_taps.iter().enumerate() {
            let w00 = (1.0 - dy) * (1.0 - dx);
            let w01 = (1.0 - dy) * dx;
            let w10 = dy * (1.0 - dx);
            let w11 = dy * dx;
            for ci in 0..c {
                out[[oy, ox, ci]] = w00 * mask[[y0, x0, ci]]
                    + w01 * mask[[y0, x1, ci]]
                    + w10 * mask[[y1, x0, ci]]
                    + w11 * mask[[y1, x1, ci]];
            }
        }
    }
    Ok(out)
}

/// Maps output index `dst` on an axis of length `out_len` to the two
/// neighbouring source indices on an axis of length `in_len`, plus the weight
/// of the second one. Requires `in_len > 0` and `out_len > 0`.
fn bilinear_source_taps(dst: usize, in_len: usize, out_len: usize) -> (usize, usize, f32) {
    let scale = in_len as f32 / out_len as f32;
    // Clamp before taking the floor: otherwise a negative coordinate yields
    // index 0 with a fractional weight that pulls in index 1.
    let src = ((dst as f32 + 0.5) * scale - 0.5).max(0.0);
    let last = in_len - 1;
    let i0 = (src.floor() as usize).min(last);
    let i1 = (i0 + 1).min(last);
    // When i0 == i1 the two weights act on the same sample and sum to 1.
    (i0, i1, src - i0 as f32)
}
/// Segmentation quality scores produced by `compute_segmentation_metrics`.
#[derive(Clone, Debug)]
pub struct SegmentationMetrics {
    /// Intersection-over-union per class index; `0.0` for a class with neither
    /// ground-truth nor predicted pixels.
    pub per_class_iou: Vec<f64>,
    /// Mean IoU, as computed by `compute_segmentation_metrics`.
    pub mean_iou: f64,
    /// Fraction of non-ignored pixels that were classified correctly.
    pub pixel_accuracy: f64,
    /// Dice coefficient per class index; `0.0` for a class with neither
    /// ground-truth nor predicted pixels.
    pub per_class_dice: Vec<f64>,
    /// Mean Dice coefficient, as computed by `compute_segmentation_metrics`.
    pub mean_dice: f64,
}
/// Computes confusion-matrix based metrics for a predicted label map.
///
/// Pixels whose ground-truth label equals `ignore_index` (when given) are
/// skipped entirely. The remaining pixels fill a `num_classes x num_classes`
/// confusion matrix (rows = ground truth, columns = prediction). From it:
/// - per-class IoU is `tp / (tp + fp + fn)`;
/// - per-class Dice is `2tp / (2tp + fp + fn)`.
///
/// A class with no ground-truth and no predicted pixels has undefined scores.
/// It is reported as `0.0` in both per-class vectors and is excluded from
/// *both* `mean_iou` and `mean_dice`, so the two means average over the same
/// classes. If no class is present, both means are `0.0`. If every pixel is
/// ignored, `pixel_accuracy` is `1.0`.
///
/// # Errors
///
/// Returns `VisionError::InvalidParameter` if:
/// - `pred` and `gt` differ in shape;
/// - `num_classes` is zero;
/// - a non-ignored pixel has a predicted or ground-truth label `>= num_classes`.
pub fn compute_segmentation_metrics(
    pred: &Array2<usize>,
    gt: &Array2<usize>,
    num_classes: usize,
    ignore_index: Option<usize>,
) -> Result<SegmentationMetrics> {
    if pred.dim() != gt.dim() {
        return Err(VisionError::InvalidParameter(
            "compute_segmentation_metrics: pred and gt must have the same shape".into(),
        ));
    }
    if num_classes == 0 {
        return Err(VisionError::InvalidParameter(
            "compute_segmentation_metrics: num_classes must be > 0".into(),
        ));
    }

    // Row-major confusion matrix: conf[truth * num_classes + predicted].
    let mut conf = vec![0u64; num_classes * num_classes];
    let mut total_valid = 0u64;
    let mut correct = 0u64;
    // Same shape plus `iter()`'s logical (row-major) order means zipping pairs
    // the two labels of the same pixel.
    for (&predicted, &truth) in pred.iter().zip(gt.iter()) {
        if ignore_index == Some(truth) {
            continue;
        }
        if predicted >= num_classes || truth >= num_classes {
            // Clamping into the last class would silently inflate its scores and
            // could count two different invalid labels as a correct match.
            return Err(VisionError::InvalidParameter(
                "compute_segmentation_metrics: labels must be < num_classes unless ignored"
                    .into(),
            ));
        }
        conf[truth * num_classes + predicted] += 1;
        total_valid += 1;
        if predicted == truth {
            correct += 1;
        }
    }

    let pixel_accuracy = if total_valid > 0 {
        correct as f64 / total_valid as f64
    } else {
        // Nothing to score (e.g. every pixel ignored): treat as perfect.
        1.0
    };

    // Row sums = ground-truth pixels per class; column sums = predicted pixels.
    let mut gt_totals = vec![0u64; num_classes];
    let mut pred_totals = vec![0u64; num_classes];
    for (truth, row) in conf.chunks_exact(num_classes).enumerate() {
        for (predicted, &count) in row.iter().enumerate() {
            gt_totals[truth] += count;
            pred_totals[predicted] += count;
        }
    }

    // `None` marks a class absent from both maps, whose scores are undefined.
    let scores: Vec<Option<(f64, f64)>> = (0..num_classes)
        .map(|c| {
            let tp = conf[c * num_classes + c];
            let fp = pred_totals[c] - tp;
            let fn_ = gt_totals[c] - tp;
            if tp + fp + fn_ == 0 {
                return None;
            }
            let (tp, fp, fn_) = (tp as f64, fp as f64, fn_ as f64);
            Some((tp / (tp + fp + fn_), 2.0 * tp / (2.0 * tp + fp + fn_)))
        })
        .collect();

    let per_class_iou: Vec<f64> = scores.iter().map(|s| s.map_or(0.0, |(iou, _)| iou)).collect();
    let per_class_dice: Vec<f64> = scores
        .iter()
        .map(|s| s.map_or(0.0, |(_, dice)| dice))
        .collect();

    let present = scores.iter().flatten().count();
    let (mean_iou, mean_dice) = if present == 0 {
        (0.0, 0.0)
    } else {
        let (iou_sum, dice_sum) = scores
            .iter()
            .flatten()
            .fold((0.0, 0.0), |(a, b), &(iou, dice)| (a + iou, b + dice));
        (iou_sum / present as f64, dice_sum / present as f64)
    };

    Ok(SegmentationMetrics {
        per_class_iou,
        mean_iou,
        pixel_accuracy,
        per_class_dice,
        mean_dice,
    })
}
/// Tuning parameters for `dense_crf_post_process`.
#[derive(Clone, Debug)]
pub struct DenseCRFParams {
    /// Number of mean-field update iterations.
    pub iterations: usize,
    /// Standard deviation, in pixels, of the Gaussian spatial kernel. It is also
    /// the positional bandwidth of the appearance kernel.
    pub spatial_sigma: f32,
    /// Colour bandwidth of the appearance kernel. `None` disables the appearance
    /// term, which also needs an RGB image to take effect.
    pub appearance_sigma: Option<f32>,
    /// Multiplier applied to the spatial message.
    pub spatial_weight: f32,
    /// Multiplier applied to the appearance message.
    pub appearance_weight: f32,
}

impl Default for DenseCRFParams {
    /// Five iterations, `spatial_sigma = 3`, `appearance_sigma = Some(10)`,
    /// `spatial_weight = 3`, `appearance_weight = 10`.
    fn default() -> Self {
        DenseCRFParams {
            appearance_sigma: Some(10.0),
            appearance_weight: 10.0,
            spatial_sigma: 3.0,
            spatial_weight: 3.0,
            iterations: 5,
        }
    }
}
pub fn dense_crf_post_process(
prob_map: &Array3<f32>,
rgb_image: Option<&Array3<f32>>,
params: &DenseCRFParams,
) -> Result<Array3<f32>> {
let (h, w, c) = prob_map.dim();
if c == 0 {
return Err(VisionError::InvalidParameter(
"dense_crf_post_process: num_classes must be > 0".into(),
));
}
if let Some(rgb) = rgb_image {
let (rh, rw, rc) = rgb.dim();
if rh != h || rw != w {
return Err(VisionError::InvalidParameter(
"dense_crf_post_process: rgb_image spatial dimensions must match prob_map".into(),
));
}
if rc != 3 {
return Err(VisionError::InvalidParameter(
"dense_crf_post_process: rgb_image must have 3 channels".into(),
));
}
}
let mut q = prob_map.clone();
let sigma_s = params.spatial_sigma.max(0.1);
let radius_s = (3.0 * sigma_s).ceil() as isize;
let spatial_kernel: Vec<f32> = (-radius_s..=radius_s)
.map(|d| (-(d as f32 * d as f32) / (2.0 * sigma_s * sigma_s)).exp())
.collect();
let kernel_sum: f32 = spatial_kernel.iter().sum();
let spatial_kernel: Vec<f32> = spatial_kernel.iter().map(|v| v / kernel_sum).collect();
for _iter in 0..params.iterations {
let mut msg_spatial = Array3::<f32>::zeros((h, w, c));
let mut tmp = Array3::<f32>::zeros((h, w, c));
for y in 0..h {
for x in 0..w {
for (ki, &kv) in spatial_kernel.iter().enumerate() {
let dx = ki as isize - radius_s;
let nx = (x as isize + dx).clamp(0, w as isize - 1) as usize;
for ci in 0..c {
tmp[[y, x, ci]] += kv * q[[y, nx, ci]];
}
}
}
}
for y in 0..h {
for x in 0..w {
for (ki, &kv) in spatial_kernel.iter().enumerate() {
let dy = ki as isize - radius_s;
let ny = (y as isize + dy).clamp(0, h as isize - 1) as usize;
for ci in 0..c {
msg_spatial[[y, x, ci]] += kv * tmp[[ny, x, ci]];
}
}
}
}
let msg_appearance =
if let (Some(rgb), Some(sigma_a)) = (rgb_image, params.appearance_sigma) {
let mut app = Array3::<f32>::zeros((h, w, c));
let sigma_a2 = 2.0 * sigma_a * sigma_a;
let r = 3usize;
for y in 0..h {
for x in 0..w {
let mut w_sum = 0.0f32;
let mut acc = vec![0.0f32; c];
let y0 = y.saturating_sub(r);
let y1 = (y + r + 1).min(h);
let x0 = x.saturating_sub(r);
let x1 = (x + r + 1).min(w);
for ny in y0..y1 {
for nx in x0..x1 {
let dr = rgb[[y, x, 0]] - rgb[[ny, nx, 0]];
let dg = rgb[[y, x, 1]] - rgb[[ny, nx, 1]];
let db = rgb[[y, x, 2]] - rgb[[ny, nx, 2]];
let colour_dist2 = dr * dr + dg * dg + db * db;
let dy = (y as f32 - ny as f32).powi(2);
let dx2 = (x as f32 - nx as f32).powi(2);
let spatial_dist2 = dy + dx2;
let w = (-colour_dist2 / sigma_a2
- spatial_dist2 / (2.0 * sigma_s * sigma_s))
.exp();
for ci in 0..c {
acc[ci] += w * q[[ny, nx, ci]];
}
w_sum += w;
}
}
if w_sum > 0.0 {
for ci in 0..c {
app[[y, x, ci]] = acc[ci] / w_sum;
}
}
}
}
Some(app)
} else {
None
};
let mut new_q = Array3::<f32>::zeros((h, w, c));
for y in 0..h {
for x in 0..w {
for ci in 0..c {
let spatial_msg = msg_spatial[[y, x, ci]];
let app_msg = msg_appearance
.as_ref()
.map(|a| a[[y, x, ci]])
.unwrap_or(0.0);
let combined =
params.spatial_weight * spatial_msg + params.appearance_weight * app_msg;
let compat = combined - q[[y, x, ci]];
new_q[[y, x, ci]] = prob_map[[y, x, ci]] * (-compat).exp();
}
let sum: f32 = (0..c).map(|ci| new_q[[y, x, ci]]).sum();
if sum > 1e-8 {
for ci in 0..c {
new_q[[y, x, ci]] /= sum;
}
} else {
let uniform = 1.0 / c as f32;
for ci in 0..c {
new_q[[y, x, ci]] = uniform;
}
}
}
}
q = new_q;
}
Ok(q)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `(h, w, c)` probability map with every class equally likely.
    fn make_uniform_probs(h: usize, w: usize, c: usize) -> Array3<f32> {
        Array3::from_elem((h, w, c), 1.0 / c as f32)
    }

    #[test]
    fn test_fcn_config_default() {
        let cfg = FCNConfig::default();
        assert_eq!(cfg.num_classes, 21);
        assert_eq!(cfg.input_channels, 3);
        assert_eq!(cfg.stride, 32);
        assert_eq!(cfg.backbone, FCNBackbone::ResNet18);
        assert!(cfg.apply_softmax);
    }

    #[test]
    fn test_fcnoutput_from_logits() {
        // Class 0 scores 1.0 everywhere; class 1 scores 2.0 on rows 2 and 3 only.
        let logits = Array3::from_shape_fn((4, 4, 3), |(y, _x, c)| {
            if c == 0 {
                1.0
            } else if y > 1 && c == 1 {
                2.0
            } else {
                0.0
            }
        });
        let out = FCNOutput::from_logits(logits).expect("FCNOutput::from_logits failed");
        assert_eq!(out.class_map[[0, 0]], 0);
        assert_eq!(out.class_map[[2, 0]], 1);
    }

    #[test]
    fn test_fcnoutput_softmax() {
        let logits = Array3::from_elem((2, 2, 2), 1.0f32);
        let out = FCNOutput::from_logits(logits)
            .expect("from_logits")
            .with_softmax();
        for y in 0..2 {
            for x in 0..2 {
                let s: f32 = (0..2).map(|c| out.logits[[y, x, c]]).sum();
                assert!((s - 1.0).abs() < 1e-5, "softmax sum={}", s);
            }
        }
    }

    #[test]
    fn test_bilinear_upsample_trivial() {
        let mask = Array3::from_elem((4, 4, 3), 0.5f32);
        let up = bilinear_upsample_mask(&mask, 8, 8).expect("upsample failed");
        assert_eq!(up.dim(), (8, 8, 3));
        for v in up.iter() {
            assert!((*v - 0.5).abs() < 1e-4, "unexpected value {}", v);
        }
    }

    #[test]
    fn test_bilinear_upsample_noop() {
        let mask = Array3::from_shape_fn((3, 3, 2), |(y, x, c)| (y + x + c) as f32 * 0.1);
        let up = bilinear_upsample_mask(&mask, 3, 3).expect("upsample failed");
        assert_eq!(up.dim(), (3, 3, 2));
        // At scale 1 every output sample lands exactly on its source pixel.
        for (got, want) in up.iter().zip(mask.iter()) {
            assert!((got - want).abs() < 1e-6, "expected {}, got {}", want, got);
        }
    }

    #[test]
    fn test_bilinear_upsample_invalid() {
        let mask = Array3::from_elem((4, 4, 2), 0.5f32);
        let res = bilinear_upsample_mask(&mask, 0, 8);
        assert!(res.is_err());
    }

    #[test]
    fn test_metrics_perfect_prediction() {
        let truth = Array2::from_shape_fn((4, 4), |(y, _x)| if y < 2 { 0 } else { 1 });
        let metrics =
            compute_segmentation_metrics(&truth, &truth, 2, None).expect("metrics failed");
        assert!((metrics.mean_iou - 1.0).abs() < 1e-9, "mIoU should be 1.0");
        assert!((metrics.pixel_accuracy - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_metrics_all_wrong() {
        let truth = Array2::from_shape_fn((4, 4), |(_, _)| 0usize);
        let pred = Array2::from_shape_fn((4, 4), |(_, _)| 1usize);
        let metrics =
            compute_segmentation_metrics(&pred, &truth, 2, None).expect("metrics failed");
        assert!(
            (metrics.pixel_accuracy).abs() < 1e-9,
            "accuracy should be 0.0"
        );
        assert!((metrics.per_class_iou[0]).abs() < 1e-9);
    }

    #[test]
    fn test_metrics_ignore_index() {
        let mut truth = Array2::from_shape_fn((4, 4), |(_, _)| 0usize);
        truth[[0, 0]] = 255;
        let pred = Array2::zeros((4, 4));
        let metrics =
            compute_segmentation_metrics(&pred, &truth, 2, Some(255)).expect("metrics failed");
        assert!((metrics.pixel_accuracy - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_dense_crf_shape_preserved() {
        let prob = make_uniform_probs(8, 8, 4);
        let crf_params = DenseCRFParams {
            iterations: 2,
            ..Default::default()
        };
        let refined = dense_crf_post_process(&prob, None, &crf_params).expect("crf failed");
        assert_eq!(refined.dim(), (8, 8, 4));
    }

    #[test]
    fn test_dense_crf_rows_sum_to_one() {
        let prob = make_uniform_probs(6, 6, 3);
        let crf_params = DenseCRFParams {
            iterations: 3,
            ..Default::default()
        };
        let refined = dense_crf_post_process(&prob, None, &crf_params).expect("crf failed");
        for y in 0..6 {
            for x in 0..6 {
                let s: f32 = (0..3).map(|c| refined[[y, x, c]]).sum();
                assert!((s - 1.0).abs() < 1e-4, "prob sum={} at ({},{})", s, y, x);
            }
        }
    }

    #[test]
    fn test_dense_crf_with_rgb() {
        let prob = make_uniform_probs(6, 6, 2);
        let rgb = Array3::<f32>::from_elem((6, 6, 3), 128.0);
        let crf_params = DenseCRFParams {
            iterations: 2,
            appearance_sigma: Some(10.0),
            ..Default::default()
        };
        let refined = dense_crf_post_process(&prob, Some(&rgb), &crf_params)
            .expect("crf with rgb failed");
        assert_eq!(refined.dim(), (6, 6, 2));
    }

    #[test]
    fn test_dense_crf_invalid_rgb_channels() {
        let prob = make_uniform_probs(4, 4, 2);
        let bad_rgb = Array3::<f32>::zeros((4, 4, 1));
        let crf_params = DenseCRFParams::default();
        let res = dense_crf_post_process(&prob, Some(&bad_rgb), &crf_params);
        assert!(res.is_err());
    }
}