use crate::error::{MinCutError, Result};
use crate::graph::{DynamicGraph, Edge, EdgeId, VertexId, Weight};
use crate::jtree::JTreeError;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Tuning knobs for a single level of the j-tree hierarchy.
#[derive(Debug, Clone)]
pub struct LevelConfig {
    /// Index of this level in the hierarchy.
    pub level: usize,
    /// Approximation parameter; stored here but not read by the code in this file.
    pub alpha: f64,
    /// When `true`, `(s, t)` query results are memoised.
    pub enable_cache: bool,
    /// Number of cached results at which eviction starts; `0` means no bound.
    pub max_cache_entries: usize,
    /// Whether to request a `WasmGraphHandle` when a level is built.
    pub wasm_available: bool,
}

impl Default for LevelConfig {
    /// Level 0, `alpha = 2.0`, caching enabled with a 10 000-entry bound, no WASM.
    fn default() -> Self {
        let enable_cache = true;
        let max_cache_entries = 10_000;
        Self {
            wasm_available: false,
            enable_cache,
            max_cache_entries,
            alpha: 2.0,
            level: 0,
        }
    }
}
/// Size information and query counters for one j-tree level.
#[derive(Debug, Clone, Default)]
pub struct LevelStatistics {
    /// Number of (super-)vertices in the level's contracted graph.
    pub vertex_count: usize,
    /// Number of distinct (super-)edges in the level's contracted graph.
    pub edge_count: usize,
    /// Queries answered from the result cache.
    pub cache_hits: usize,
    /// Queries whose result had to be produced without the cache.
    pub cache_misses: usize,
    /// All recorded queries (hits plus misses).
    pub total_queries: usize,
    /// Running mean of per-query wall-clock time, in microseconds.
    pub avg_query_time_us: f64,
}
/// Result of an `(s, t)` cut query on one level.
#[derive(Debug, Clone)]
pub struct PathCutResult {
    /// Cut value reported by the level; `f64::INFINITY` when no finite value was found
    /// (e.g. both endpoints in the same super-vertex, or no connecting path).
    pub value: f64,
    /// One endpoint of the query, as an original (uncontracted) vertex id.
    pub source: VertexId,
    /// The other endpoint of the query, as an original (uncontracted) vertex id.
    pub target: VertexId,
    /// `true` when the value was served from the level's cache.
    pub from_cache: bool,
    /// Wall-clock time, in microseconds, spent when the value was computed.
    pub compute_time_us: f64,
}
/// A graph obtained by merging vertices of an input graph into super-vertices.
///
/// Parallel edges between the same pair of super-vertices are stored as a single edge whose
/// weight is the sum of the merged weights.
#[derive(Debug, Clone)]
pub struct ContractedGraph {
    /// Original vertex id -> id of the super-vertex that currently contains it.
    vertex_map: HashMap<VertexId, VertexId>,
    /// Super-vertex id -> set of original vertex ids merged into it.
    super_vertices: HashMap<VertexId, HashSet<VertexId>>,
    /// Undirected super-edges keyed by their `(smaller, larger)` endpoint pair.
    edges: HashMap<(VertexId, VertexId), Weight>,
    /// Next id to hand out when a new super-vertex is created by contraction.
    next_super_id: VertexId,
    /// Hierarchy level this graph belongs to.
    level: usize,
}
impl ContractedGraph {
    /// Builds an uncontracted copy of `graph`: every vertex becomes its own super-vertex.
    ///
    /// Parallel edges between the same endpoints are merged by summing their weights.
    /// Self-loops are dropped, matching `contract`, which never keeps an edge whose endpoints
    /// lie in the same super-vertex (such an edge cannot cross any cut).
    pub fn from_graph(graph: &DynamicGraph, level: usize) -> Self {
        let mut contracted = Self::new(level);
        for v in graph.vertices() {
            contracted.vertex_map.insert(v, v);
            let mut members = HashSet::new();
            members.insert(v);
            contracted.super_vertices.insert(v, members);
            // Fresh super-vertex ids must never collide with an original vertex id.
            contracted.next_super_id = contracted.next_super_id.max(v + 1);
        }
        for edge in graph.edges() {
            if edge.source == edge.target {
                continue;
            }
            let key = Self::canonical_key(edge.source, edge.target);
            *contracted.edges.entry(key).or_insert(0.0) += edge.weight;
        }
        contracted
    }

    /// Creates an empty contracted graph (no vertices, no edges) for level `level`.
    pub fn new(level: usize) -> Self {
        Self {
            vertex_map: HashMap::new(),
            super_vertices: HashMap::new(),
            edges: HashMap::new(),
            next_super_id: 0,
            level,
        }
    }

    /// Orders an undirected endpoint pair so `(u, v)` and `(v, u)` map to the same key.
    fn canonical_key(u: VertexId, v: VertexId) -> (VertexId, VertexId) {
        if u <= v {
            (u, v)
        } else {
            (v, u)
        }
    }

    /// Merges the super-vertices containing original vertices `u` and `v`.
    ///
    /// Returns the id of the super-vertex that now contains both. If they already share a
    /// super-vertex, nothing changes and that id is returned.
    ///
    /// Otherwise:
    /// - a fresh id is allocated and both member sets are merged under it;
    /// - edges touching either old super-vertex are re-targeted to it, with parallel edges summed;
    /// - edges between the two old super-vertices are dropped, since they are now internal.
    ///
    /// # Errors
    /// `JTreeError::VertexNotFound` (converted into the crate error) if `u` or `v` is unknown.
    pub fn contract(&mut self, u: VertexId, v: VertexId) -> Result<VertexId> {
        let u_super = *self
            .vertex_map
            .get(&u)
            .ok_or_else(|| JTreeError::VertexNotFound(u))?;
        let v_super = *self
            .vertex_map
            .get(&v)
            .ok_or_else(|| JTreeError::VertexNotFound(v))?;
        if u_super == v_super {
            return Ok(u_super);
        }
        let new_super = self.next_super_id;
        self.next_super_id += 1;

        // Extend the larger member set with the smaller one instead of allocating a third set.
        let mut merged = self.super_vertices.remove(&u_super).unwrap_or_default();
        let mut other = self.super_vertices.remove(&v_super).unwrap_or_default();
        if merged.len() < other.len() {
            std::mem::swap(&mut merged, &mut other);
        }
        merged.extend(other);
        for &orig_v in &merged {
            self.vertex_map.insert(orig_v, new_super);
        }
        self.super_vertices.insert(new_super, merged);

        // Only edges touching one of the two old super-vertices change; the rest stay put.
        let is_old = |x: VertexId| x == u_super || x == v_super;
        let remap = |x: VertexId| if is_old(x) { new_super } else { x };
        let incident: Vec<(VertexId, VertexId)> = self
            .edges
            .keys()
            .filter(|&&(a, b)| is_old(a) || is_old(b))
            .copied()
            .collect();
        for key in incident {
            if let Some(weight) = self.edges.remove(&key) {
                let (a, b) = (remap(key.0), remap(key.1));
                // An edge between the two merged super-vertices is now internal: drop it.
                if a != b {
                    *self.edges.entry(Self::canonical_key(a, b)).or_insert(0.0) += weight;
                }
            }
        }
        Ok(new_super)
    }

    /// Number of super-vertices currently in the graph.
    pub fn vertex_count(&self) -> usize {
        self.super_vertices.len()
    }

    /// Number of distinct super-edges currently in the graph.
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// Iterates over super-edges as `(u, v, weight)` with `u <= v`, in unspecified order.
    pub fn edges(&self) -> impl Iterator<Item = (VertexId, VertexId, Weight)> + '_ {
        self.edges.iter().map(|(&(u, v), &w)| (u, v, w))
    }

    /// Super-vertex currently containing original vertex `v`, if `v` is known.
    pub fn get_super_vertex(&self, v: VertexId) -> Option<VertexId> {
        self.vertex_map.get(&v).copied()
    }

    /// Original vertices merged into super-vertex `super_v`, if it exists.
    pub fn get_original_vertices(&self, super_v: VertexId) -> Option<&HashSet<VertexId>> {
        self.super_vertices.get(&super_v)
    }

    /// Iterates over the ids of all current super-vertices, in unspecified order.
    pub fn super_vertices(&self) -> impl Iterator<Item = VertexId> + '_ {
        self.super_vertices.keys().copied()
    }

    /// Aggregated weight of the super-edge between `u` and `v` (order-insensitive), if any.
    pub fn edge_weight(&self, u: VertexId, v: VertexId) -> Option<Weight> {
        let key = Self::canonical_key(u, v);
        self.edges.get(&key).copied()
    }

    /// Hierarchy level this graph belongs to.
    pub fn level(&self) -> usize {
        self.level
    }
}
/// One level of a j-tree hierarchy, answering cut queries on its contracted graph.
///
/// The `Send + Sync` bound lets level objects be moved and shared across threads.
pub trait JTreeLevel: Send + Sync {
    // ----- Introspection -----

    /// Index of this level in the hierarchy.
    fn level(&self) -> usize;

    /// Snapshot of this level's size and query counters.
    fn statistics(&self) -> LevelStatistics;

    /// The contracted graph this level operates on.
    fn contracted_graph(&self) -> &ContractedGraph;

    // ----- Queries -----

    /// Cut query between original vertices `s` and `t`.
    fn min_cut(&mut self, s: VertexId, t: VertexId) -> Result<PathCutResult>;

    /// Cut value over a set of terminal vertices.
    fn multi_terminal_cut(&mut self, terminals: &[VertexId]) -> Result<f64>;

    /// Expands a partition expressed in super-vertex ids into original vertex ids.
    fn refine_cut(&mut self, coarse_partition: &HashSet<VertexId>) -> Result<HashSet<VertexId>>;

    // ----- Updates -----

    /// Adds an edge of weight `weight` between original vertices `u` and `v`.
    fn insert_edge(&mut self, u: VertexId, v: VertexId, weight: Weight) -> Result<()>;

    /// Removes the edge between original vertices `u` and `v`.
    fn delete_edge(&mut self, u: VertexId, v: VertexId) -> Result<()>;

    /// Discards any memoised query results.
    fn invalidate_cache(&mut self);
}
/// A j-tree level that answers `(s, t)` queries with shortest-path distances on its
/// contracted graph, memoising results in a per-level cache.
pub struct BmsspJTreeLevel {
    /// The contracted graph that queries run against.
    contracted: ContractedGraph,
    /// Level index plus cache and WASM settings.
    config: LevelConfig,
    /// Counters reported by `statistics()`.
    stats: LevelStatistics,
    /// Memoised results keyed by the unordered pair `(min(s, t), max(s, t))`.
    cache: HashMap<(VertexId, VertexId), PathCutResult>,
    // Stored but never read in this file; presumably reserved for a WASM code path.
    #[allow(dead_code)]
    wasm_handle: Option<WasmGraphHandle>,
}
/// Placeholder handle for a WASM-side copy of a level's graph.
///
/// In this file it is never backed by real memory: `ptr` is always `0` and `valid` is always
/// `false`.
#[derive(Debug)]
pub struct WasmGraphHandle {
    // None of these fields are read in this file; `dead_code` is allowed presumably until a
    // WASM backend consumes them.
    #[allow(dead_code)]
    ptr: usize,
    #[allow(dead_code)]
    vertex_count: u32,
    #[allow(dead_code)]
    valid: bool,
}

impl WasmGraphHandle {
    /// Creates a placeholder (not-yet-valid) handle sized for `vertex_count` vertices.
    ///
    /// Never fails as written; the `Result` leaves room for a real WASM allocation.
    #[allow(dead_code)]
    fn new(vertex_count: u32) -> Result<Self> {
        let handle = Self {
            vertex_count,
            ptr: 0,
            valid: false,
        };
        Ok(handle)
    }

    /// Whether the crate was compiled with the `wasm` feature enabled.
    #[allow(dead_code)]
    fn is_available() -> bool {
        let compiled_with_wasm = cfg!(feature = "wasm");
        compiled_with_wasm
    }
}
impl BmsspJTreeLevel {
pub fn new(contracted: ContractedGraph, config: LevelConfig) -> Result<Self> {
let stats = LevelStatistics {
vertex_count: contracted.vertex_count(),
edge_count: contracted.edge_count(),
..Default::default()
};
let wasm_handle = if config.wasm_available {
WasmGraphHandle::new(contracted.vertex_count() as u32).ok()
} else {
None
};
Ok(Self {
contracted,
config,
stats,
cache: HashMap::new(),
wasm_handle,
})
}
pub fn from_contracted(contracted: ContractedGraph, level: usize) -> Self {
let config = LevelConfig {
level,
..Default::default()
};
Self {
stats: LevelStatistics {
vertex_count: contracted.vertex_count(),
edge_count: contracted.edge_count(),
..Default::default()
},
contracted,
config,
cache: HashMap::new(),
wasm_handle: None,
}
}
fn native_shortest_paths(&self, source: VertexId) -> HashMap<VertexId, f64> {
use std::cmp::Ordering;
use std::collections::BinaryHeap;
#[derive(Debug)]
struct State {
cost: f64,
vertex: VertexId,
}
impl PartialEq for State {
fn eq(&self, other: &Self) -> bool {
self.cost == other.cost && self.vertex == other.vertex
}
}
impl Eq for State {}
impl PartialOrd for State {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for State {
fn cmp(&self, other: &Self) -> Ordering {
other
.cost
.partial_cmp(&self.cost)
.unwrap_or(Ordering::Equal)
}
}
let mut distances: HashMap<VertexId, f64> = HashMap::new();
let mut heap = BinaryHeap::new();
let mut adj: HashMap<VertexId, Vec<(VertexId, f64)>> = HashMap::new();
for (u, v, w) in self.contracted.edges() {
adj.entry(u).or_default().push((v, w));
adj.entry(v).or_default().push((u, w));
}
distances.insert(source, 0.0);
heap.push(State {
cost: 0.0,
vertex: source,
});
while let Some(State { cost, vertex }) = heap.pop() {
if let Some(&d) = distances.get(&vertex) {
if cost > d {
continue;
}
}
if let Some(neighbors) = adj.get(&vertex) {
for &(next, edge_weight) in neighbors {
let next_cost = cost + edge_weight;
let is_better = distances.get(&next).map(|&d| next_cost < d).unwrap_or(true);
if is_better {
distances.insert(next, next_cost);
heap.push(State {
cost: next_cost,
vertex: next,
});
}
}
}
}
distances
}
fn cache_key(s: VertexId, t: VertexId) -> (VertexId, VertexId) {
if s <= t {
(s, t)
} else {
(t, s)
}
}
fn update_stats(&mut self, from_cache: bool, compute_time_us: f64) {
self.stats.total_queries += 1;
if from_cache {
self.stats.cache_hits += 1;
} else {
self.stats.cache_misses += 1;
}
let n = self.stats.total_queries as f64;
self.stats.avg_query_time_us =
(self.stats.avg_query_time_us * (n - 1.0) + compute_time_us) / n;
}
}
impl JTreeLevel for BmsspJTreeLevel {
    /// Level index taken from this level's configuration.
    fn level(&self) -> usize {
        self.config.level
    }

    /// Copy of the current counters.
    fn statistics(&self) -> LevelStatistics {
        self.stats.clone()
    }

    /// Shortest-path distance between the super-vertices containing `s` and `t`.
    ///
    /// This distance is the level's cut estimate. The value is `f64::INFINITY` when `s` and
    /// `t` share a super-vertex or no path connects them. When caching is enabled, results are
    /// memoised per unordered pair.
    ///
    /// # Errors
    /// `JTreeError::VertexNotFound` if `s` or `t` is not part of the contracted graph (only
    /// checked on a cache miss).
    fn min_cut(&mut self, s: VertexId, t: VertexId) -> Result<PathCutResult> {
        let start = std::time::Instant::now();
        let key = Self::cache_key(s, t);

        if self.config.enable_cache {
            if let Some(cached) = self.cache.get(&key) {
                // Entries are keyed by the unordered pair, so the stored endpoints may be
                // swapped relative to this query: report them in the caller's order.
                let result = PathCutResult {
                    source: s,
                    target: t,
                    from_cache: true,
                    ..cached.clone()
                };
                self.update_stats(true, start.elapsed().as_micros() as f64);
                return Ok(result);
            }
        }

        let s_super = self
            .contracted
            .get_super_vertex(s)
            .ok_or_else(|| JTreeError::VertexNotFound(s))?;
        let t_super = self
            .contracted
            .get_super_vertex(t)
            .ok_or_else(|| JTreeError::VertexNotFound(t))?;

        if s_super == t_super {
            // Both endpoints are merged into one super-vertex: not separable at this level.
            let result = PathCutResult {
                value: f64::INFINITY,
                source: s,
                target: t,
                from_cache: false,
                compute_time_us: start.elapsed().as_micros() as f64,
            };
            self.update_stats(false, result.compute_time_us);
            return Ok(result);
        }

        let distances = self.native_shortest_paths(s_super);
        let cut_value = distances.get(&t_super).copied().unwrap_or(f64::INFINITY);
        let compute_time = start.elapsed().as_micros() as f64;
        let result = PathCutResult {
            value: cut_value,
            source: s,
            target: t,
            from_cache: false,
            compute_time_us: compute_time,
        };

        if self.config.enable_cache {
            let cap = self.config.max_cache_entries;
            // A cap of 0 means "unbounded".
            if cap > 0 && self.cache.len() >= cap {
                // Evict about half the entries (arbitrary ones; `HashMap` has no order), but
                // always at least one so that `cap == 1` still keeps the cache bounded.
                let evict = (cap / 2).max(1);
                let victims: Vec<_> = self.cache.keys().take(evict).copied().collect();
                for k in victims {
                    self.cache.remove(&k);
                }
            }
            self.cache.insert(key, result.clone());
        }

        self.update_stats(false, compute_time);
        Ok(result)
    }

    /// Smallest `min_cut` value over all unordered pairs of `terminals`.
    ///
    /// Returns `f64::INFINITY` when fewer than two terminals are given.
    fn multi_terminal_cut(&mut self, terminals: &[VertexId]) -> Result<f64> {
        if terminals.len() < 2 {
            return Ok(f64::INFINITY);
        }
        let mut min_cut = f64::INFINITY;
        for (i, &a) in terminals.iter().enumerate() {
            for &b in &terminals[i + 1..] {
                let result = self.min_cut(a, b)?;
                min_cut = min_cut.min(result.value);
            }
        }
        Ok(min_cut)
    }

    /// Maps each super-vertex id in `coarse_partition` to its original vertices.
    ///
    /// Ids that are not current super-vertices are ignored.
    fn refine_cut(&mut self, coarse_partition: &HashSet<VertexId>) -> Result<HashSet<VertexId>> {
        let mut refined = HashSet::new();
        for &super_v in coarse_partition {
            if let Some(original_vertices) = self.contracted.get_original_vertices(super_v) {
                refined.extend(original_vertices);
            }
        }
        Ok(refined)
    }

    /// Adds `weight` to the super-edge between the super-vertices of `u` and `v`.
    ///
    /// An edge inside a single super-vertex does not change the contracted graph. The cache
    /// is cleared in every successful case.
    ///
    /// # Errors
    /// `JTreeError::VertexNotFound` if `u` or `v` is unknown.
    fn insert_edge(&mut self, u: VertexId, v: VertexId, weight: Weight) -> Result<()> {
        let u_super = self
            .contracted
            .get_super_vertex(u)
            .ok_or_else(|| JTreeError::VertexNotFound(u))?;
        let v_super = self
            .contracted
            .get_super_vertex(v)
            .ok_or_else(|| JTreeError::VertexNotFound(v))?;
        if u_super != v_super {
            let key = ContractedGraph::canonical_key(u_super, v_super);
            *self.contracted.edges.entry(key).or_insert(0.0) += weight;
            self.stats.edge_count = self.contracted.edge_count();
        }
        self.invalidate_cache();
        Ok(())
    }

    /// Removes the super-edge between the super-vertices of `u` and `v`, then clears the cache.
    ///
    /// NOTE(review): super-edges aggregate the weights of every original edge between two
    /// super-vertices. This removes the whole aggregated edge, not just the `u`-`v`
    /// contribution. The trait signature carries no weight to subtract.
    ///
    /// # Errors
    /// `JTreeError::VertexNotFound` if `u` or `v` is unknown.
    fn delete_edge(&mut self, u: VertexId, v: VertexId) -> Result<()> {
        let u_super = self
            .contracted
            .get_super_vertex(u)
            .ok_or_else(|| JTreeError::VertexNotFound(u))?;
        let v_super = self
            .contracted
            .get_super_vertex(v)
            .ok_or_else(|| JTreeError::VertexNotFound(v))?;
        if u_super != v_super {
            let key = ContractedGraph::canonical_key(u_super, v_super);
            self.contracted.edges.remove(&key);
            self.stats.edge_count = self.contracted.edge_count();
        }
        self.invalidate_cache();
        Ok(())
    }

    /// Drops every memoised query result.
    fn invalidate_cache(&mut self) {
        self.cache.clear();
    }

    /// The contracted graph backing this level.
    fn contracted_graph(&self) -> &ContractedGraph {
        &self.contracted
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Path graph 1 -(2.0)- 2 -(1.0)- 3 -(2.0)- 4.
    fn create_test_graph() -> DynamicGraph {
        let graph = DynamicGraph::new();
        graph.insert_edge(1, 2, 2.0).unwrap();
        graph.insert_edge(2, 3, 1.0).unwrap();
        graph.insert_edge(3, 4, 2.0).unwrap();
        graph
    }

    /// Level 0 built over the uncontracted test graph with default settings.
    fn level_over_test_graph() -> BmsspJTreeLevel {
        let contracted = ContractedGraph::from_graph(&create_test_graph(), 0);
        BmsspJTreeLevel::from_contracted(contracted, 0)
    }

    #[test]
    fn test_contracted_graph_from_graph() {
        let contracted = ContractedGraph::from_graph(&create_test_graph(), 0);
        assert_eq!(contracted.vertex_count(), 4);
        assert_eq!(contracted.edge_count(), 3);
        assert_eq!(contracted.level(), 0);
    }

    #[test]
    fn test_contracted_graph_contract() {
        let mut contracted = ContractedGraph::from_graph(&create_test_graph(), 0);
        let merged_id = contracted.contract(1, 2).unwrap();
        assert_eq!(contracted.vertex_count(), 3);
        let members = contracted.get_original_vertices(merged_id).unwrap();
        assert!(members.contains(&1));
        assert!(members.contains(&2));
    }

    #[test]
    fn test_bmssp_level_creation() {
        let contracted = ContractedGraph::from_graph(&create_test_graph(), 0);
        let level = BmsspJTreeLevel::new(contracted, LevelConfig::default()).unwrap();
        assert_eq!(level.level(), 0);
        let stats = level.statistics();
        assert_eq!(stats.vertex_count, 4);
        assert_eq!(stats.edge_count, 3);
    }

    #[test]
    fn test_min_cut_query() {
        let mut level = level_over_test_graph();
        let outcome = level.min_cut(1, 4).unwrap();
        assert!(outcome.value.is_finite());
        assert!(!outcome.from_cache);
    }

    #[test]
    fn test_min_cut_caching() {
        let mut level = level_over_test_graph();

        let first = level.min_cut(1, 4).unwrap();
        assert!(!first.from_cache);

        let second = level.min_cut(1, 4).unwrap();
        assert!(second.from_cache);
        assert_eq!(first.value, second.value);

        // The cache key is order-insensitive.
        let reversed = level.min_cut(4, 1).unwrap();
        assert!(reversed.from_cache);
    }

    #[test]
    fn test_multi_terminal_cut() {
        let mut level = level_over_test_graph();
        let cut = level.multi_terminal_cut(&[1, 2, 3, 4]).unwrap();
        assert!(cut.is_finite());
    }

    #[test]
    fn test_cache_invalidation() {
        let mut level = level_over_test_graph();

        level.min_cut(1, 4).unwrap();
        assert_eq!(level.statistics().cache_hits, 0);

        level.min_cut(1, 4).unwrap();
        assert_eq!(level.statistics().cache_hits, 1);

        level.invalidate_cache();
        let after_clear = level.min_cut(1, 4).unwrap();
        assert!(!after_clear.from_cache);
    }

    #[test]
    fn test_level_config_default() {
        let defaults = LevelConfig::default();
        assert_eq!(defaults.level, 0);
        assert_eq!(defaults.alpha, 2.0);
        assert!(defaults.enable_cache);
        assert_eq!(defaults.max_cache_entries, 10_000);
    }

    #[test]
    fn test_refine_cut() {
        let mut contracted = ContractedGraph::from_graph(&create_test_graph(), 0);
        let super_12 = contracted.contract(1, 2).unwrap();
        let mut level = BmsspJTreeLevel::from_contracted(contracted, 0);

        let coarse: HashSet<VertexId> = std::iter::once(super_12).collect();
        let refined = level.refine_cut(&coarse).unwrap();

        assert!(refined.contains(&1));
        assert!(refined.contains(&2));
        assert!(!refined.contains(&3));
        assert!(!refined.contains(&4));
    }
}