use std::collections::HashMap;
/// Configuration shared by the RAG builders in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct RagConfig {
    /// Pixel/voxel neighbourhood used when looking for adjacent regions.
    ///
    /// [`build_rag_2d`] uses 8-connectivity when this is `8` and 4-connectivity otherwise;
    /// [`build_rag_3d`] uses 26-connectivity when this is `26` and the 6 face neighbours otherwise.
    pub connectivity: u8,
    /// Merge cut-off carried alongside the connectivity setting.
    ///
    /// NOTE(review): not read by any function visible in this file; presumably intended as an
    /// edge-weight threshold for merging — confirm against callers.
    pub merge_threshold: f64,
}

impl Default for RagConfig {
    /// 4-connectivity (2-D) with a `merge_threshold` of `10.0`.
    fn default() -> Self {
        Self {
            connectivity: 4,
            merge_threshold: 10.0,
        }
    }
}
/// Undirected, weighted edge between two adjacent regions.
#[derive(Debug, Clone, PartialEq)]
pub struct RagEdge {
    /// First region label; the builders in this module store the smaller label here.
    pub region1: usize,
    /// Second region label; the builders in this module store the larger label here.
    pub region2: usize,
    /// Edge weight; the builders compute the mean absolute intensity difference over all
    /// neighbouring pixel/voxel pairs that straddle the two regions.
    pub weight: f64,
}

/// Region adjacency graph: the set of region labels plus the edges between adjacent regions.
///
/// `Default` yields the empty graph.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct RegionAdjacencyGraph {
    /// Region labels; the builders in this module emit each label once, in ascending order.
    pub regions: Vec<usize>,
    /// Edges between adjacent regions.
    pub edges: Vec<RagEdge>,
}
/// Builds a region adjacency graph (RAG) from a 2-D label map and a matching intensity image.
///
/// Two regions are adjacent when at least one pair of neighbouring pixels carries their two
/// labels. Each edge weight is the mean absolute intensity difference over all such pixel pairs.
///
/// * `labels` – label map indexed `[row][col]`; the grid width is taken from `labels[0]`.
/// * `image` – intensities indexed exactly like `labels`.
/// * `config` – `connectivity == 8` adds diagonal neighbours; any other value means 4-connectivity.
///
/// Returns the distinct labels in ascending order and the edges sorted by `(region1, region2)`,
/// with `region1 < region2`. Empty `labels` yields an empty graph.
///
/// # Panics
///
/// Panics (index out of bounds) if any row of `labels` or `image` is shorter than `labels[0]`,
/// or if `image` has fewer rows than `labels`.
pub fn build_rag_2d(
    labels: &[Vec<usize>],
    image: &[Vec<f64>],
    config: &RagConfig,
) -> RegionAdjacencyGraph {
    // Only "forward" offsets are visited, so every unordered pixel pair is examined exactly once.
    const FORWARD_4: [(isize, isize); 2] = [(0, 1), (1, 0)];
    const FORWARD_DIAG: [(isize, isize); 2] = [(1, 1), (1, -1)];

    let rows = labels.len();
    if rows == 0 {
        return RegionAdjacencyGraph {
            regions: vec![],
            edges: vec![],
        };
    }
    let cols = labels[0].len();
    let diagonals: &[(isize, isize)] = if config.connectivity == 8 {
        &FORWARD_DIAG
    } else {
        &[]
    };
    let in_grid = |r: isize, c: isize| r >= 0 && r < rows as isize && c >= 0 && c < cols as isize;

    // BTreeMap keeps region pairs ordered, so the edge list is identical from run to run
    // (std HashMap iteration order is randomised per process).
    let mut edge_acc: std::collections::BTreeMap<(usize, usize), (f64, usize)> =
        std::collections::BTreeMap::new();
    let mut region_set: std::collections::BTreeSet<usize> = std::collections::BTreeSet::new();

    for r in 0..rows {
        for c in 0..cols {
            let la = labels[r][c];
            let ia = image[r][c];
            region_set.insert(la);
            for &(dr, dc) in FORWARD_4.iter().chain(diagonals) {
                let (nr, nc) = (r as isize + dr, c as isize + dc);
                if !in_grid(nr, nc) {
                    continue;
                }
                let (nr, nc) = (nr as usize, nc as usize);
                let lb = labels[nr][nc];
                if lb == la {
                    continue;
                }
                // Normalise the pair so (a, b) and (b, a) accumulate into one edge.
                let key = (la.min(lb), la.max(lb));
                let entry = edge_acc.entry(key).or_insert((0.0, 0));
                entry.0 += (ia - image[nr][nc]).abs();
                entry.1 += 1;
            }
        }
    }

    let edges = edge_acc
        .into_iter()
        .map(|((region1, region2), (sum, count))| RagEdge {
            region1,
            region2,
            // Every inserted key has count >= 1; the guard is purely defensive.
            weight: if count > 0 { sum / count as f64 } else { 0.0 },
        })
        .collect();
    RegionAdjacencyGraph {
        regions: region_set.into_iter().collect(),
        edges,
    }
}
/// Builds a region adjacency graph (RAG) from a 3-D label volume and a matching intensity volume.
///
/// Two regions are adjacent when at least one pair of neighbouring voxels carries their two
/// labels. Each edge weight is the mean absolute intensity difference over all such voxel pairs.
///
/// * `labels` – label volume indexed `[z][y][x]`; the extents are taken from `labels[0]` and
///   `labels[0][0]`.
/// * `image` – intensities indexed exactly like `labels`.
/// * `config` – `connectivity == 26` adds edge- and corner-sharing neighbours; any other value
///   uses only the 6 face neighbours.
///
/// Returns the distinct labels in ascending order and the edges sorted by `(region1, region2)`,
/// with `region1 < region2`. Empty `labels` yields an empty graph.
///
/// # Panics
///
/// Panics (index out of bounds) if any slice/row of `labels` or `image` is smaller than the
/// extents taken from `labels[0]`.
pub fn build_rag_3d(
    labels: &[Vec<Vec<usize>>],
    image: &[Vec<Vec<f64>>],
    config: &RagConfig,
) -> RegionAdjacencyGraph {
    // The 13 "forward" offsets of the 26-neighbourhood (lexicographically positive), so each
    // unordered voxel pair is examined exactly once.
    const FORWARD_FACES: [(isize, isize, isize); 3] = [(1, 0, 0), (0, 1, 0), (0, 0, 1)];
    const FORWARD_EDGES_AND_CORNERS: [(isize, isize, isize); 10] = [
        (1, 1, 0),
        (1, -1, 0),
        (1, 0, 1),
        (1, 0, -1),
        (0, 1, 1),
        (0, 1, -1),
        (1, 1, 1),
        (1, 1, -1),
        (1, -1, 1),
        (1, -1, -1),
    ];

    let nz = labels.len();
    if nz == 0 {
        return RegionAdjacencyGraph {
            regions: vec![],
            edges: vec![],
        };
    }
    let ny = labels[0].len();
    let nx = if ny > 0 { labels[0][0].len() } else { 0 };
    let extra: &[(isize, isize, isize)] = if config.connectivity == 26 {
        &FORWARD_EDGES_AND_CORNERS
    } else {
        &[]
    };
    let in_bounds = |v: isize, len: usize| v >= 0 && v < len as isize;

    // BTreeMap keeps region pairs ordered, so the edge list is identical from run to run
    // (std HashMap iteration order is randomised per process).
    let mut edge_acc: std::collections::BTreeMap<(usize, usize), (f64, usize)> =
        std::collections::BTreeMap::new();
    let mut region_set: std::collections::BTreeSet<usize> = std::collections::BTreeSet::new();

    for z in 0..nz {
        for y in 0..ny {
            for x in 0..nx {
                let la = labels[z][y][x];
                let ia = image[z][y][x];
                region_set.insert(la);
                for &(dz, dy, dx) in FORWARD_FACES.iter().chain(extra) {
                    let (z2, y2, x2) = (z as isize + dz, y as isize + dy, x as isize + dx);
                    if !(in_bounds(z2, nz) && in_bounds(y2, ny) && in_bounds(x2, nx)) {
                        continue;
                    }
                    let (z2, y2, x2) = (z2 as usize, y2 as usize, x2 as usize);
                    let lb = labels[z2][y2][x2];
                    if lb == la {
                        continue;
                    }
                    // Normalise the pair so (a, b) and (b, a) accumulate into one edge.
                    let key = (la.min(lb), la.max(lb));
                    let entry = edge_acc.entry(key).or_insert((0.0, 0));
                    entry.0 += (ia - image[z2][y2][x2]).abs();
                    entry.1 += 1;
                }
            }
        }
    }

    let edges = edge_acc
        .into_iter()
        .map(|((region1, region2), (sum, count))| RagEdge {
            region1,
            region2,
            // Every inserted key has count >= 1; the guard is purely defensive.
            weight: if count > 0 { sum / count as f64 } else { 0.0 },
        })
        .collect();
    RegionAdjacencyGraph {
        regions: region_set.into_iter().collect(),
        edges,
    }
}
/// Greedily merges every region smaller than `min_region_size` pixels into its most similar
/// neighbour, rewriting `labels` in place.
///
/// Regions are processed smallest first; ties are broken by the lower label, so results are
/// deterministic. A small region is merged into the adjacent region whose edge in `rag` has the
/// lowest weight (ties broken by the lower neighbour label). After a merge, the absorbing region
/// inherits the absorbed region's pixel count and adjacencies. It may therefore be merged again
/// later if it is still too small.
///
/// A small region with no neighbour other than itself is left unchanged. It does not stop
/// the processing of other small regions.
///
/// * `rag` – adjacency graph of `labels` (e.g. from [`build_rag_2d`]); edges are treated as
///   undirected.
/// * `labels` – label map, rewritten in place with the final (merged) labels.
/// * `min_region_size` – regions with strictly fewer pixels than this are merged.
///
/// Returns the number of merges performed.
pub fn merge_small_regions(
    rag: &RegionAdjacencyGraph,
    labels: &mut [Vec<usize>],
    min_region_size: usize,
) -> usize {
    if labels.is_empty() {
        return 0;
    }

    // Pixel count per label currently alive (merged labels are removed).
    let mut size_map: HashMap<usize, usize> = HashMap::new();
    for &lbl in labels.iter().flatten() {
        *size_map.entry(lbl).or_insert(0) += 1;
    }

    // Undirected adjacency lists: label -> [(neighbour label, edge weight)].
    let mut adj: HashMap<usize, Vec<(usize, f64)>> = HashMap::new();
    for edge in &rag.edges {
        adj.entry(edge.region1)
            .or_default()
            .push((edge.region2, edge.weight));
        adj.entry(edge.region2)
            .or_default()
            .push((edge.region1, edge.weight));
    }

    // Merge links: absorbed label -> label it was merged into. Both sides of every inserted
    // link are unmerged (root) labels, so the links form a forest and following them always
    // terminates, however long the chain grows.
    let mut parent: HashMap<usize, usize> = HashMap::new();
    let resolve = |links: &HashMap<usize, usize>, mut lbl: usize| -> usize {
        while let Some(&next) = links.get(&lbl) {
            if next == lbl {
                break;
            }
            lbl = next;
        }
        lbl
    };

    // Small regions found to have no mergeable neighbour; skipped instead of ending the loop.
    let mut stuck: std::collections::HashSet<usize> = std::collections::HashSet::new();
    let mut n_merges = 0usize;
    loop {
        let candidate = size_map
            .iter()
            .filter(|&(lbl, &sz)| {
                sz < min_region_size && !stuck.contains(lbl) && !parent.contains_key(lbl)
            })
            .min_by_key(|&(&lbl, &sz)| (sz, lbl))
            .map(|(&lbl, _)| lbl);
        let small_lbl = match candidate {
            Some(l) => l,
            None => break,
        };

        let best_neighbour = adj.get(&small_lbl).and_then(|neighbours| {
            neighbours
                .iter()
                .filter_map(|&(nb, w)| {
                    let root = resolve(&parent, nb);
                    if root != small_lbl {
                        Some((root, w))
                    } else {
                        None
                    }
                })
                .min_by(|(la, wa), (lb, wb)| {
                    wa.partial_cmp(wb)
                        .unwrap_or(std::cmp::Ordering::Equal)
                        .then(la.cmp(lb))
                })
        });
        let target = match best_neighbour {
            Some((t, _)) => t,
            None => {
                stuck.insert(small_lbl);
                continue;
            }
        };

        parent.insert(small_lbl, target);
        // The merged region now borders everything the absorbed region bordered.
        if let Some(inherited) = adj.remove(&small_lbl) {
            adj.entry(target).or_default().extend(inherited);
        }
        let small_size = size_map.remove(&small_lbl).unwrap_or(0);
        *size_map.entry(target).or_insert(0) += small_size;
        // Cheap safeguard: a region that just gained neighbours deserves another look.
        stuck.remove(&target);
        n_merges += 1;
    }

    // Memoise the root of each distinct label so long chains are walked once, not per pixel.
    let mut resolved: HashMap<usize, usize> = HashMap::new();
    for lbl in labels.iter_mut().flatten() {
        let old = *lbl;
        *lbl = *resolved
            .entry(old)
            .or_insert_with(|| resolve(&parent, old));
    }
    n_merges
}
/// Converts a RAG into a dense, symmetric weight matrix.
///
/// Returns `(labels, matrix)`:
/// * `labels` is a copy of `rag.regions`, sorted ascending.
/// * `matrix[i][j] == matrix[j][i]` is the weight of the edge between `labels[i]` and
///   `labels[j]`, or `0.0` when no such edge exists.
///
/// Edges with an endpoint missing from `rag.regions` are skipped. When the same pair appears
/// in several edges, the one listed last wins.
pub fn rag_to_adjacency_matrix(rag: &RegionAdjacencyGraph) -> (Vec<usize>, Vec<Vec<f64>>) {
    let mut sorted = rag.regions.clone();
    sorted.sort_unstable();

    let size = rag.regions.len();
    let index_of: HashMap<usize, usize> = sorted.iter().copied().zip(0..).collect();

    // (row, col, weight) for every edge whose endpoints are both known regions.
    let placements = rag.edges.iter().filter_map(|e| {
        let i = *index_of.get(&e.region1)?;
        let j = *index_of.get(&e.region2)?;
        Some((i, j, e.weight))
    });

    let mut matrix = vec![vec![0.0f64; size]; size];
    for (i, j, w) in placements {
        matrix[i][j] = w;
        matrix[j][i] = w;
    }
    (sorted, matrix)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 4x4 label map: columns 0-1 belong to region 1, columns 2-3 to region 2.
    fn two_region_labels() -> Vec<Vec<usize>> {
        (0..4).map(|_| vec![1, 1, 2, 2]).collect()
    }

    /// Intensities matching `two_region_labels`: 10.0 on the left half, 50.0 on the right.
    fn two_region_image() -> Vec<Vec<f64>> {
        vec![vec![10.0, 10.0, 50.0, 50.0]; 4]
    }

    /// RAG of the two-region fixture built with the default configuration.
    fn two_region_rag() -> RegionAdjacencyGraph {
        build_rag_2d(
            &two_region_labels(),
            &two_region_image(),
            &RagConfig::default(),
        )
    }

    #[test]
    fn test_two_regions_one_edge() {
        let rag = two_region_rag();
        assert_eq!(
            rag.edges.len(),
            1,
            "exactly 1 edge between region 1 and region 2"
        );
        assert_eq!(rag.regions.len(), 2);
    }

    #[test]
    fn test_edge_weight_correct() {
        let w = two_region_rag().edges[0].weight;
        assert!((w - 40.0).abs() < 1e-10, "expected weight 40.0, got {}", w);
    }

    #[test]
    fn test_merge_small_regions() {
        // A single-pixel region 3 in the bottom-right corner of an otherwise uniform region 1.
        let mut labels = vec![vec![1usize; 4]; 4];
        labels[3][3] = 3;
        let mut image = vec![vec![10.0f64; 4]; 4];
        image[3][3] = 50.0;
        let config = RagConfig {
            connectivity: 4,
            merge_threshold: 5.0,
        };
        let rag = build_rag_2d(&labels, &image, &config);
        let n = merge_small_regions(&rag, &mut labels, 2);
        assert_eq!(n, 1, "one merge should have happened");
        assert_eq!(labels[3][3], 1);
    }

    #[test]
    fn test_adjacency_matrix_symmetry() {
        let (_sorted_labels, matrix) = rag_to_adjacency_matrix(&two_region_rag());
        let n = matrix.len();
        let cells = (0..n).flat_map(|i| (0..n).map(move |j| (i, j)));
        for (i, j) in cells {
            assert!(
                (matrix[i][j] - matrix[j][i]).abs() < 1e-12,
                "matrix not symmetric at [{i}][{j}]"
            );
        }
    }

    #[test]
    fn test_3d_rag_basic() {
        // Two identical z-slices, each a single row split into regions 1 and 2.
        let labels = vec![vec![vec![1usize, 2]]; 2];
        let image = vec![vec![vec![5.0f64, 15.0f64]]; 2];
        let config = RagConfig {
            connectivity: 6,
            ..Default::default()
        };
        let rag = build_rag_3d(&labels, &image, &config);
        assert!(!rag.edges.is_empty(), "should have at least one edge");
    }
}