/// Tunable parameters for [`DenseCrf`] inference.
#[derive(Debug, Clone)]
pub struct CrfConfig {
/// Number of update iterations performed by `DenseCrf::infer`.
pub n_iter: usize,
/// Multiplier applied to the Gaussian-filtered (position-only) message.
pub spatial_weight: f64,
/// Multiplier applied to the bilateral (position + colour) message.
pub bilateral_weight: f64,
/// Standard deviation, in pixels, of the Gaussian used for the spatial message.
pub spatial_sigma: f64,
/// Standard deviation of the colour term of the bilateral message
/// (same units as the image channel values).
pub color_sigma: f64,
/// Standard deviation, in pixels, of the position term of the bilateral message.
pub bilateral_sigma_pos: f64,
/// Number of labels per pixel; every unary row is resized to this length.
pub n_labels: usize,
}
impl Default for CrfConfig {
fn default() -> Self {
Self {
n_iter: 5,
spatial_weight: 3.0,
bilateral_weight: 10.0,
spatial_sigma: 3.0,
color_sigma: 8.0,
bilateral_sigma_pos: 5.0,
n_labels: 2,
}
}
}
/// Grid CRF state: configuration plus the flattened unaries and image that
/// `infer` operates on. Populated through the builder-style setters.
#[derive(Debug, Clone)]
pub struct DenseCrf {
// Inference parameters supplied to `new`.
config: CrfConfig,
// Flattened unaries, pixel-major: `config.n_labels` consecutive values per pixel.
unary: Vec<f64>,
// Image pixels flattened row by row, in the order the rows were supplied.
image: Vec<[f64; 3]>,
// Length of the first image row given to `set_image_2d` (0 until an image is set).
width: usize,
// Number of image rows given to `set_image_2d` (0 until an image is set).
height: usize,
}
impl DenseCrf {
    /// Creates an empty CRF with the given configuration.
    ///
    /// Unaries and image start out empty; supply them with [`Self::set_unary`]
    /// and [`Self::set_image_2d`] before calling [`Self::infer`].
    pub fn new(config: CrfConfig) -> Self {
        Self {
            config,
            unary: Vec::new(),
            image: Vec::new(),
            width: 0,
            height: 0,
        }
    }

    /// Stores the per-pixel unary terms, one inner `Vec` per pixel.
    ///
    /// Each row is resized to exactly `config.n_labels` entries: shorter rows
    /// are padded with `0.0`, longer rows are truncated. The values are
    /// treated as costs by `infer` (a lower value makes the label more likely).
    pub fn set_unary(mut self, unary: Vec<Vec<f64>>) -> Self {
        let n_labels = self.config.n_labels;
        self.unary = unary
            .into_iter()
            .flat_map(|mut row| {
                row.resize(n_labels, 0.0);
                row
            })
            .collect();
        self
    }

    /// Stores the guidance image and records the grid dimensions.
    ///
    /// `height` is the number of rows and `width` is the length of the first
    /// row; pixels are flattened row by row. Rows are not checked for equal
    /// length.
    pub fn set_image_2d(mut self, image: &[Vec<[f64; 3]>]) -> Self {
        self.height = image.len();
        self.width = image.first().map_or(0, Vec::len);
        self.image = image.iter().flatten().copied().collect();
        self
    }

    /// Runs `config.n_iter` mean-field-style updates and returns the most
    /// probable label index for every pixel (first label wins ties).
    ///
    /// Each iteration combines a Gaussian (position-only) message and a
    /// bilateral (position + colour) message, weights them, and penalises
    /// every label by the combined message mass of all *other* labels before
    /// re-normalising with a numerically stable softmax.
    ///
    /// Returns an empty vector when there are no labels or no unaries.
    ///
    /// The pairwise messages are defined on the `height x width` grid recorded
    /// by [`Self::set_image_2d`]. If that grid does not contain exactly as many
    /// pixels as the unaries describe (for example because no image was set),
    /// the messages cannot be evaluated; instead of indexing out of bounds the
    /// pairwise stage is skipped and the unary-only labelling is returned.
    pub fn infer(&self) -> Vec<usize> {
        let n_labels = self.config.n_labels;
        if n_labels == 0 {
            return Vec::new();
        }
        let n_pixels = self.unary.len() / n_labels;
        if n_pixels == 0 {
            return Vec::new();
        }

        let mut q = init_q(&self.unary, n_pixels, n_labels);

        // The filters allocate and index `height * width * n_labels` values, so
        // they are only safe to run when the grid matches the unary pixel count.
        let grid_matches_unary = self.height.checked_mul(self.width) == Some(n_pixels);
        if grid_matches_unary {
            // Scratch buffer reused for every pixel of every iteration.
            let mut scores = vec![0.0f64; n_labels];
            for _ in 0..self.config.n_iter {
                let spatial = gaussian_filter_2d_per_label(
                    &q,
                    self.height,
                    self.width,
                    n_labels,
                    self.config.spatial_sigma,
                );
                let bilateral = bilateral_message_2d(
                    &q,
                    &self.image,
                    self.height,
                    self.width,
                    n_labels,
                    self.config.bilateral_sigma_pos,
                    self.config.color_sigma,
                );
                // Both messages are computed from the previous `q` before any
                // pixel is updated, so updating `q` in place below is safe.
                let message: Vec<f64> = spatial
                    .iter()
                    .zip(&bilateral)
                    .map(|(&s, &b)| {
                        self.config.spatial_weight * s + self.config.bilateral_weight * b
                    })
                    .collect();

                for pixel in 0..n_pixels {
                    let base = pixel * n_labels;
                    let msg = &message[base..base + n_labels];
                    let unary = &self.unary[base..base + n_labels];
                    let total: f64 = msg.iter().sum();

                    // Negative energy of each label: unary cost plus the
                    // message mass assigned to every other label.
                    for ((score, &u), &m) in scores.iter_mut().zip(unary).zip(msg) {
                        *score = -(u + (total - m));
                    }

                    // Softmax with the maximum subtracted for stability.
                    let max_score = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                    let mut z = 0.0f64;
                    for score in scores.iter_mut() {
                        *score = (*score - max_score).exp();
                        z += *score;
                    }
                    if z < 1e-20 {
                        z = 1.0;
                    }
                    for (dst, &score) in q[base..base + n_labels].iter_mut().zip(&scores) {
                        *dst = score / z;
                    }
                }
            }
        }

        q.chunks_exact(n_labels)
            .map(|probs| {
                let mut best = 0usize;
                for (label, &p) in probs.iter().enumerate().skip(1) {
                    if p > probs[best] {
                        best = label;
                    }
                }
                best
            })
            .collect()
    }
}
/// Refines a 2-D segmentation with a [`DenseCrf`] and returns one label per
/// grid cell, shaped `rows x cols`.
///
/// * `unary_log_prob` - `[row][col][label]` unary terms. Despite the name the
///   values are used as costs: they are negated before exponentiation, so a
///   lower value makes a label more likely.
/// * `image` - guidance image; it is assumed to have the same `rows x cols`
///   shape as the unaries (not checked here).
/// * `config` - inference parameters (cloned for the internal CRF).
///
/// The grid width is taken from the first unary row. Every other row is
/// normalised to that width so that the flat pixel index `r * cols + c` stays
/// aligned with the grid: missing cells get an empty (all-zero after padding)
/// unary, surplus cells are ignored. An empty input yields an empty result and
/// a zero-width grid yields `rows` empty rows.
pub fn apply_to_segmentation_2d(
    unary_log_prob: &[Vec<Vec<f64>>],
    image: &[Vec<[f64; 3]>],
    config: &CrfConfig,
) -> Vec<Vec<usize>> {
    let rows = unary_log_prob.len();
    if rows == 0 {
        return Vec::new();
    }
    let cols = unary_log_prob[0].len();
    if cols == 0 {
        // Nothing to label; also avoids chunking/dividing by a zero width.
        return vec![Vec::new(); rows];
    }

    // Exactly `cols` unaries per row keeps ragged input from shifting the
    // flat indices of all following rows.
    let unary_flat: Vec<Vec<f64>> = unary_log_prob
        .iter()
        .flat_map(|row| {
            row.iter()
                .cloned()
                .chain(std::iter::repeat_with(Vec::new))
                .take(cols)
        })
        .collect();

    let flat_labels = DenseCrf::new(config.clone())
        .set_unary(unary_flat)
        .set_image_2d(image)
        .infer();

    // Cells without a computed label (e.g. when `n_labels` is 0) stay 0.
    let mut result = vec![vec![0usize; cols]; rows];
    for (out_row, labels) in result.iter_mut().zip(flat_labels.chunks(cols)) {
        out_row[..labels.len()].copy_from_slice(labels);
    }
    result
}
/// Builds the initial per-pixel label distribution as a softmax over the
/// negated unaries.
///
/// `unary` is pixel-major with `n_labels` values per pixel; the result has the
/// same layout and `n_pixels * n_labels` entries. The per-pixel maximum is
/// subtracted before exponentiating for numerical stability, and a vanishing
/// normaliser is replaced by `1.0`.
fn init_q(unary: &[f64], n_pixels: usize, n_labels: usize) -> Vec<f64> {
    let mut probs = Vec::with_capacity(n_pixels * n_labels);
    for pixel in 0..n_pixels {
        let start = pixel * n_labels;
        let energies = &unary[start..start + n_labels];

        let peak = energies
            .iter()
            .map(|&e| -e)
            .fold(f64::NEG_INFINITY, f64::max);
        let weights: Vec<f64> = energies.iter().map(|&e| (-e - peak).exp()).collect();

        let mut norm = 0.0;
        for &w in &weights {
            norm += w;
        }
        if norm < 1e-20 {
            norm = 1.0;
        }

        probs.extend(weights.iter().map(|&w| w / norm));
    }
    probs
}
/// Smooths every label channel of `q` with a separable Gaussian of width
/// `sigma` over a `height x width` grid.
///
/// `q` is pixel-major with `n_labels` values per pixel (row-major pixels). A
/// horizontal pass is followed by a vertical pass; at the borders only the
/// in-bounds taps are used and the result is re-normalised by their total
/// weight. Returns a buffer with the same layout as `q`.
fn gaussian_filter_2d_per_label(
    q: &[f64],
    height: usize,
    width: usize,
    n_labels: usize,
    sigma: f64,
) -> Vec<f64> {
    let taps = gaussian_kernel_1d(sigma);
    let half = (taps.len() / 2) as isize;
    let total = height * width * n_labels;
    let mut filtered = vec![0.0f64; total];
    let mut horizontal = vec![0.0f64; total];

    // Pass 1: filter along each row.
    for row in 0..height {
        for col in 0..width {
            let dst = (row * width + col) * n_labels;
            for label in 0..n_labels {
                let (mut sum, mut weight) = (0.0f64, 0.0f64);
                for (k, &tap) in taps.iter().enumerate() {
                    let src_col = col as isize + k as isize - half;
                    if src_col < 0 || src_col >= width as isize {
                        continue;
                    }
                    let src = row * width + src_col as usize;
                    sum += tap * q[src * n_labels + label];
                    weight += tap;
                }
                horizontal[dst + label] = if weight > 0.0 { sum / weight } else { 0.0 };
            }
        }
    }

    // Pass 2: filter the row-filtered values along each column.
    for row in 0..height {
        for col in 0..width {
            let dst = (row * width + col) * n_labels;
            for label in 0..n_labels {
                let (mut sum, mut weight) = (0.0f64, 0.0f64);
                for (k, &tap) in taps.iter().enumerate() {
                    let src_row = row as isize + k as isize - half;
                    if src_row < 0 || src_row >= height as isize {
                        continue;
                    }
                    let src = src_row as usize * width + col;
                    sum += tap * horizontal[src * n_labels + label];
                    weight += tap;
                }
                filtered[dst + label] = if weight > 0.0 { sum / weight } else { 0.0 };
            }
        }
    }

    filtered
}
/// Computes a bilateral-filtered copy of `q`: each pixel receives a weighted
/// average of the label distributions in a square window around it, with
/// weights that decay with both pixel distance (`sigma_pos`) and colour
/// difference (`sigma_col`).
///
/// `q` is pixel-major with `n_labels` values per pixel and `image` holds one
/// colour per pixel in the same row-major order. The window radius is
/// `ceil(3 * sigma_pos)` (at least 1), clipped to the grid. Pixels whose index
/// lies beyond the end of `image` are treated as colour `[0, 0, 0]`. When
/// `image` is empty the position-only Gaussian filter is used instead.
fn bilateral_message_2d(
    q: &[f64],
    image: &[[f64; 3]],
    height: usize,
    width: usize,
    n_labels: usize,
    sigma_pos: f64,
    sigma_col: f64,
) -> Vec<f64> {
    if image.is_empty() {
        return gaussian_filter_2d_per_label(q, height, width, n_labels, sigma_pos);
    }

    let mut filtered = vec![0.0f64; height * width * n_labels];
    let reach = ((3.0 * sigma_pos).ceil() as usize).max(1);
    let pos_scale = 0.5 / (sigma_pos * sigma_pos);
    let col_scale = 0.5 / (sigma_col * sigma_col);
    // Colour of a flat pixel index, black when the image is too short.
    let colour_at = |idx: usize| -> [f64; 3] { image.get(idx).copied().unwrap_or([0.0; 3]) };

    // Per-label accumulator, reset for every centre pixel.
    let mut sums = vec![0.0f64; n_labels];

    for row in 0..height {
        for col in 0..width {
            let centre = row * width + col;
            let centre_rgb = colour_at(centre);
            let window_rows = row.saturating_sub(reach)..(row + reach + 1).min(height);
            let window_cols = col.saturating_sub(reach)..(col + reach + 1).min(width);

            sums.fill(0.0);
            let mut weight_total = 0.0f64;

            for nr in window_rows {
                for nc in window_cols.clone() {
                    let neighbour = nr * width + nc;
                    let neighbour_rgb = colour_at(neighbour);

                    let dy = (row as f64) - (nr as f64);
                    let dx = (col as f64) - (nc as f64);
                    let d_pos = dy * dy + dx * dx;

                    let diff = [
                        centre_rgb[0] - neighbour_rgb[0],
                        centre_rgb[1] - neighbour_rgb[1],
                        centre_rgb[2] - neighbour_rgb[2],
                    ];
                    let d_col = diff[0] * diff[0] + diff[1] * diff[1] + diff[2] * diff[2];

                    let w = (-pos_scale * d_pos - col_scale * d_col).exp();
                    weight_total += w;

                    let probs = &q[neighbour * n_labels..neighbour * n_labels + n_labels];
                    for (sum, &p) in sums.iter_mut().zip(probs) {
                        *sum += w * p;
                    }
                }
            }

            // A vanishing total weight leaves the output at zero.
            if weight_total > 1e-20 {
                let dst = &mut filtered[centre * n_labels..centre * n_labels + n_labels];
                for (out, &sum) in dst.iter_mut().zip(&sums) {
                    *out = sum / weight_total;
                }
            }
        }
    }

    filtered
}
/// Builds a normalised 1-D Gaussian kernel with standard deviation `sigma`.
///
/// The kernel has `2 * radius + 1` taps with `radius = ceil(3 * sigma)`,
/// clamped to at least 1, and its taps sum to 1.
///
/// When `sigma * sigma` is zero (a zero or underflowing `sigma`) or NaN the
/// Gaussian is degenerate: the centre tap would evaluate `0 / 0` and poison
/// the whole kernel with NaN. In that case the zero-width limit is returned
/// instead, a three-tap unit impulse `[0, 1, 0]`, which leaves filtered data
/// unchanged.
fn gaussian_kernel_1d(sigma: f64) -> Vec<f64> {
    let variance = sigma * sigma;
    // `!(x > 0.0)` is deliberately used so that NaN is caught as well.
    if !(variance > 0.0) {
        return vec![0.0, 1.0, 0.0];
    }

    // Float-to-integer casts saturate, so a negative `sigma` gives radius 1.
    let radius = ((3.0 * sigma).ceil() as usize).max(1);
    let mut kernel: Vec<f64> = (0..=2 * radius)
        .map(|i| {
            let x = i as f64 - radius as f64;
            (-0.5 * x * x / variance).exp()
        })
        .collect();

    let sum: f64 = kernel.iter().sum();
    if sum > 0.0 {
        for tap in kernel.iter_mut() {
            *tap /= sum;
        }
    }
    kernel
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-label unaries: the first half of the pixels prefers label 0 (low
    /// cost 0.1), the second half prefers label 1.
    fn make_unary_2labels(n_pixels: usize) -> Vec<Vec<f64>> {
        let split = n_pixels / 2;
        let mut unary = Vec::with_capacity(n_pixels);
        for idx in 0..n_pixels {
            unary.push(if idx < split {
                vec![0.1, 5.0]
            } else {
                vec![5.0, 0.1]
            });
        }
        unary
    }

    /// Every inferred label must be a valid label index.
    #[test]
    fn test_infer_labels_in_valid_range() {
        let n_pixels = 4;
        let image: Vec<Vec<[f64; 3]>> = vec![vec![[0.0; 3]; 2], vec![[255.0, 0.0, 0.0]; 2]];
        let crf = DenseCrf::new(CrfConfig {
            n_iter: 3,
            n_labels: 2,
            ..CrfConfig::default()
        })
        .set_unary(make_unary_2labels(n_pixels))
        .set_image_2d(&image);

        let labels = crf.infer();

        assert_eq!(labels.len(), n_pixels);
        for label in labels {
            assert!(label < 2, "label {} out of range [0,2)", label);
        }
    }

    /// With both pairwise weights at zero the unaries alone decide the labels.
    #[test]
    fn test_infer_respects_strong_unary() {
        let prefers_zero = vec![0.01, 100.0];
        let prefers_one = vec![100.0, 0.01];
        let unary: Vec<Vec<f64>> = vec![
            prefers_zero.clone(),
            prefers_zero,
            prefers_one.clone(),
            prefers_one,
        ];
        let image: Vec<Vec<[f64; 3]>> = vec![vec![[0.0; 3]; 2]; 2];
        let config = CrfConfig {
            spatial_weight: 0.0,
            bilateral_weight: 0.0,
            n_iter: 5,
            n_labels: 2,
            ..Default::default()
        };

        let labels = DenseCrf::new(config)
            .set_unary(unary)
            .set_image_2d(&image)
            .infer();

        assert_eq!(labels[0], 0);
        assert_eq!(labels[1], 0);
        assert_eq!(labels[2], 1);
        assert_eq!(labels[3], 1);
    }

    /// The 2-D wrapper returns a `rows x cols` grid of valid labels.
    #[test]
    fn test_apply_to_segmentation_2d_shape() {
        let (rows, cols, n_labels) = (3usize, 4usize, 2usize);

        let mut unary_log_prob: Vec<Vec<Vec<f64>>> = Vec::with_capacity(rows);
        for _ in 0..rows {
            let mut row = Vec::with_capacity(cols);
            for c in 0..cols {
                row.push(if c < cols / 2 {
                    vec![0.1f64, 5.0]
                } else {
                    vec![5.0, 0.1]
                });
            }
            unary_log_prob.push(row);
        }
        let image: Vec<Vec<[f64; 3]>> = vec![vec![[128.0; 3]; cols]; rows];
        let config = CrfConfig {
            n_iter: 2,
            n_labels,
            ..Default::default()
        };

        let result = apply_to_segmentation_2d(&unary_log_prob, &image, &config);

        assert_eq!(result.len(), rows);
        assert_eq!(result[0].len(), cols);
        for &l in result.iter().flatten() {
            assert!(l < n_labels, "label {l} out of range");
        }
    }

    /// A CRF without unaries yields no labels.
    #[test]
    fn test_empty_input() {
        let crf = DenseCrf::new(CrfConfig {
            n_labels: 2,
            ..Default::default()
        });
        assert!(crf.infer().is_empty());
    }
}