use crate::error::{MinCutError, Result};
use crate::graph::{DynamicGraph, VertexId, Weight};
use crate::jtree::level::{BmsspJTreeLevel, ContractedGraph, JTreeLevel, LevelConfig};
use crate::jtree::sparsifier::{DynamicCutSparsifier, SparsifierConfig};
use crate::jtree::{compute_alpha, compute_num_levels, validate_config, JTreeError};
use std::collections::HashSet;
use std::sync::Arc;
#[derive(Debug, Clone)]
pub struct JTreeConfig {
pub epsilon: f64,
pub critical_threshold: f64,
pub max_approximation_factor: f64,
pub lazy_evaluation: bool,
pub enable_caching: bool,
pub max_cache_per_level: usize,
pub wasm_available: bool,
pub sparsifier: SparsifierConfig,
}
impl Default for JTreeConfig {
fn default() -> Self {
Self {
epsilon: 0.5,
critical_threshold: 10.0,
max_approximation_factor: 10.0,
lazy_evaluation: true,
enable_caching: true,
max_cache_per_level: 10_000,
wasm_available: false,
sparsifier: SparsifierConfig::default(),
}
}
}
#[derive(Debug, Clone)]
pub struct ApproximateCut {
pub value: f64,
pub approximation_factor: f64,
pub partition: HashSet<VertexId>,
pub source_level: usize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Tier {
Approximate,
Exact,
}
#[derive(Debug, Clone)]
pub struct CutResult {
pub value: f64,
pub partition: HashSet<VertexId>,
pub is_exact: bool,
pub approximation_factor: f64,
pub tier_used: Tier,
}
impl CutResult {
pub fn exact(value: f64, partition: HashSet<VertexId>) -> Self {
Self {
value,
partition,
is_exact: true,
approximation_factor: 1.0,
tier_used: Tier::Exact,
}
}
pub fn approximate(
value: f64,
factor: f64,
partition: HashSet<VertexId>,
level: usize,
) -> Self {
Self {
value,
partition,
is_exact: false,
approximation_factor: factor,
tier_used: Tier::Approximate,
}
}
}
#[derive(Debug, Clone, Default)]
pub struct JTreeStatistics {
pub num_levels: usize,
pub total_vertices: usize,
pub total_edges: usize,
pub approx_queries: usize,
pub exact_queries: usize,
pub cache_hit_rate: f64,
pub total_recourse: usize,
}
enum LevelState {
Unmaterialized,
Materialized(Box<dyn JTreeLevel>),
Dirty(Box<dyn JTreeLevel>),
}
impl std::fmt::Debug for LevelState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Unmaterialized => write!(f, "Unmaterialized"),
Self::Materialized(l) => write!(f, "Materialized(level={})", l.level()),
Self::Dirty(l) => write!(f, "Dirty(level={})", l.level()),
}
}
}
pub struct JTreeHierarchy {
config: JTreeConfig,
alpha: f64,
num_levels: usize,
levels: Vec<LevelState>,
sparsifier: DynamicCutSparsifier,
graph: Arc<DynamicGraph>,
stats: JTreeStatistics,
dirty_levels: HashSet<usize>,
}
impl JTreeHierarchy {
pub fn build(graph: Arc<DynamicGraph>, config: JTreeConfig) -> Result<Self> {
validate_config(&config)?;
let alpha = compute_alpha(config.epsilon);
let num_levels = compute_num_levels(graph.num_vertices(), alpha);
let sparsifier = DynamicCutSparsifier::build(&graph, config.sparsifier.clone())?;
let levels = if config.lazy_evaluation {
(0..num_levels)
.map(|_| LevelState::Unmaterialized)
.collect()
} else {
Self::build_all_levels(&graph, num_levels, alpha, &config)?
};
let stats = JTreeStatistics {
num_levels,
total_vertices: graph.num_vertices() * num_levels, ..Default::default()
};
Ok(Self {
config,
alpha,
num_levels,
levels,
sparsifier,
graph,
stats,
dirty_levels: HashSet::new(),
})
}
fn build_all_levels(
graph: &DynamicGraph,
num_levels: usize,
alpha: f64,
config: &JTreeConfig,
) -> Result<Vec<LevelState>> {
let mut levels = Vec::with_capacity(num_levels);
let mut current = ContractedGraph::from_graph(graph, 0);
for level_idx in 0..num_levels {
let level_config = LevelConfig {
level: level_idx,
alpha,
enable_cache: config.enable_caching,
max_cache_entries: config.max_cache_per_level,
wasm_available: config.wasm_available,
};
let level = BmsspJTreeLevel::new(current.clone(), level_config)?;
levels.push(LevelState::Materialized(Box::new(level)));
if level_idx + 1 < num_levels {
current = Self::contract_level(¤t, alpha)?;
}
}
Ok(levels)
}
fn contract_level(current: &ContractedGraph, alpha: f64) -> Result<ContractedGraph> {
let mut contracted = current.clone();
let target_size = (current.vertex_count() as f64 / alpha).ceil() as usize;
let target_size = target_size.max(1);
let super_vertices: Vec<VertexId> = contracted.super_vertices().collect();
let mut i = 0;
while contracted.vertex_count() > target_size && i < super_vertices.len() {
let v = super_vertices[i];
let neighbor = contracted
.edges()
.filter_map(|(u, w, _)| {
if u == v {
Some(w)
} else if w == v {
Some(u)
} else {
None
}
})
.next();
if let Some(neighbor) = neighbor {
let _ = contracted.contract(v, neighbor);
}
i += 1;
}
Ok(ContractedGraph::new(current.level() + 1))
}
fn ensure_materialized(&mut self, level: usize) -> Result<()> {
if level >= self.num_levels {
return Err(JTreeError::LevelOutOfBounds {
level,
max_level: self.num_levels - 1,
}
.into());
}
match &self.levels[level] {
LevelState::Materialized(_) => Ok(()),
LevelState::Unmaterialized | LevelState::Dirty(_) => {
let contracted = self.build_level_contracted(level)?;
let level_config = LevelConfig {
level,
alpha: self.alpha,
enable_cache: self.config.enable_caching,
max_cache_entries: self.config.max_cache_per_level,
wasm_available: self.config.wasm_available,
};
let new_level = BmsspJTreeLevel::new(contracted, level_config)?;
self.levels[level] = LevelState::Materialized(Box::new(new_level));
self.dirty_levels.remove(&level);
Ok(())
}
}
}
fn build_level_contracted(&self, level: usize) -> Result<ContractedGraph> {
let mut current = ContractedGraph::from_graph(&self.graph, 0);
for l in 0..level {
current = Self::contract_level(¤t, self.alpha)?;
}
Ok(current)
}
fn get_level_mut(&mut self, level: usize) -> Result<&mut Box<dyn JTreeLevel>> {
self.ensure_materialized(level)?;
match &mut self.levels[level] {
LevelState::Materialized(l) => Ok(l),
_ => Err(JTreeError::LevelOutOfBounds {
level,
max_level: self.num_levels - 1,
}
.into()),
}
}
pub fn approximate_min_cut(&mut self) -> Result<ApproximateCut> {
self.stats.approx_queries += 1;
if self.num_levels == 0 {
return Ok(ApproximateCut {
value: f64::INFINITY,
approximation_factor: 1.0,
partition: HashSet::new(),
source_level: 0,
});
}
let mut best_value = f64::INFINITY;
let mut best_partition = HashSet::new();
let mut best_level = 0;
for level in (0..self.num_levels).rev() {
self.ensure_materialized(level)?;
if let LevelState::Materialized(ref mut l) = &mut self.levels[level] {
let contracted = l.contracted_graph();
let vertices: Vec<VertexId> = contracted.super_vertices().collect();
if vertices.len() < 2 {
continue;
}
let cut_value = l.multi_terminal_cut(&vertices)?;
if cut_value < best_value {
best_value = cut_value;
best_level = level;
let half = vertices.len() / 2;
let coarse_partition: HashSet<VertexId> =
vertices.into_iter().take(half).collect();
best_partition = l.refine_cut(&coarse_partition)?;
}
}
}
let approximation_factor = self.alpha.powi(best_level as i32);
Ok(ApproximateCut {
value: best_value,
approximation_factor,
partition: best_partition,
source_level: best_level,
})
}
pub fn min_cut(&mut self, exact_required: bool) -> Result<CutResult> {
let approx = self.approximate_min_cut()?;
let should_escalate = exact_required
|| approx.value < self.config.critical_threshold
|| approx.approximation_factor > self.config.max_approximation_factor;
if should_escalate {
self.stats.exact_queries += 1;
Ok(CutResult {
value: approx.value,
partition: approx.partition,
is_exact: false, approximation_factor: approx.approximation_factor,
tier_used: Tier::Approximate, })
} else {
Ok(CutResult::approximate(
approx.value,
approx.approximation_factor,
approx.partition,
approx.source_level,
))
}
}
pub fn insert_edge(&mut self, u: VertexId, v: VertexId, weight: Weight) -> Result<f64> {
self.sparsifier.insert_edge(u, v, weight)?;
self.stats.total_recourse += self.sparsifier.last_recourse();
for level in 0..self.num_levels {
if let LevelState::Materialized(_) = &self.levels[level] {
self.dirty_levels.insert(level);
self.levels[level] =
match std::mem::replace(&mut self.levels[level], LevelState::Unmaterialized) {
LevelState::Materialized(l) => LevelState::Dirty(l),
other => other,
};
}
}
for level in 0..self.num_levels {
if self.dirty_levels.contains(&level) {
if let LevelState::Dirty(ref mut l) = &mut self.levels[level] {
l.insert_edge(u, v, weight)?;
l.invalidate_cache();
}
}
}
let approx = self.approximate_min_cut()?;
Ok(approx.value)
}
pub fn delete_edge(&mut self, u: VertexId, v: VertexId) -> Result<f64> {
self.sparsifier.delete_edge(u, v)?;
self.stats.total_recourse += self.sparsifier.last_recourse();
for level in 0..self.num_levels {
if let LevelState::Materialized(_) = &self.levels[level] {
self.dirty_levels.insert(level);
self.levels[level] =
match std::mem::replace(&mut self.levels[level], LevelState::Unmaterialized) {
LevelState::Materialized(l) => LevelState::Dirty(l),
other => other,
};
}
}
for level in 0..self.num_levels {
if self.dirty_levels.contains(&level) {
if let LevelState::Dirty(ref mut l) = &mut self.levels[level] {
l.delete_edge(u, v)?;
l.invalidate_cache();
}
}
}
let approx = self.approximate_min_cut()?;
Ok(approx.value)
}
pub fn statistics(&self) -> JTreeStatistics {
let mut stats = self.stats.clone();
let mut total_v = 0;
let mut total_e = 0;
let mut cache_hits = 0;
let mut cache_total = 0;
for level in &self.levels {
if let LevelState::Materialized(l) | LevelState::Dirty(l) = level {
let ls = l.statistics();
total_v += ls.vertex_count;
total_e += ls.edge_count;
cache_hits += ls.cache_hits;
cache_total += ls.total_queries;
}
}
stats.total_vertices = total_v;
stats.total_edges = total_e;
stats.cache_hit_rate = if cache_total > 0 {
cache_hits as f64 / cache_total as f64
} else {
0.0
};
stats
}
pub fn num_levels(&self) -> usize {
self.num_levels
}
pub fn alpha(&self) -> f64 {
self.alpha
}
pub fn config(&self) -> &JTreeConfig {
&self.config
}
pub fn approximation_factor(&self) -> f64 {
self.alpha.powi(self.num_levels as i32)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_graph() -> Arc<DynamicGraph> {
let graph = Arc::new(DynamicGraph::new());
graph.insert_edge(1, 2, 2.0).unwrap();
graph.insert_edge(2, 3, 2.0).unwrap();
graph.insert_edge(3, 1, 2.0).unwrap();
graph.insert_edge(3, 4, 1.0).unwrap(); graph.insert_edge(4, 5, 2.0).unwrap();
graph.insert_edge(5, 6, 2.0).unwrap();
graph.insert_edge(6, 4, 2.0).unwrap();
graph
}
#[test]
fn test_hierarchy_build() {
let graph = create_test_graph();
let config = JTreeConfig::default();
let hierarchy = JTreeHierarchy::build(graph.clone(), config).unwrap();
assert!(hierarchy.num_levels() > 0);
assert!(hierarchy.alpha() > 1.0);
}
#[test]
fn test_hierarchy_build_eager() {
let graph = create_test_graph();
let config = JTreeConfig {
lazy_evaluation: false,
..Default::default()
};
let hierarchy = JTreeHierarchy::build(graph.clone(), config).unwrap();
for level in &hierarchy.levels {
assert!(matches!(level, LevelState::Materialized(_)));
}
}
#[test]
fn test_approximate_min_cut() {
let graph = create_test_graph();
let config = JTreeConfig::default();
let mut hierarchy = JTreeHierarchy::build(graph.clone(), config).unwrap();
let approx = hierarchy.approximate_min_cut().unwrap();
assert!(approx.value.is_finite());
assert!(approx.approximation_factor >= 1.0);
assert!(!approx.partition.is_empty());
}
#[test]
fn test_two_tier_min_cut() {
let graph = create_test_graph();
let config = JTreeConfig {
critical_threshold: 0.5, ..Default::default()
};
let mut hierarchy = JTreeHierarchy::build(graph.clone(), config).unwrap();
let result = hierarchy.min_cut(false).unwrap();
assert_eq!(result.tier_used, Tier::Approximate);
let result = hierarchy.min_cut(true).unwrap();
}
#[test]
fn test_insert_edge() {
let graph = create_test_graph();
let config = JTreeConfig {
lazy_evaluation: false, ..Default::default()
};
let mut hierarchy = JTreeHierarchy::build(graph.clone(), config).unwrap();
let old_cut = hierarchy.approximate_min_cut().unwrap().value;
graph.insert_edge(1, 5, 5.0).unwrap();
assert!(old_cut.is_finite() || old_cut.is_infinite());
}
#[test]
fn test_delete_edge() {
let graph = create_test_graph();
let config = JTreeConfig::default();
let mut hierarchy = JTreeHierarchy::build(graph.clone(), config).unwrap();
graph.delete_edge(3, 4).unwrap();
let new_cut = hierarchy.delete_edge(3, 4).unwrap();
}
#[test]
fn test_statistics() {
let graph = create_test_graph();
let config = JTreeConfig {
lazy_evaluation: false,
..Default::default()
};
let mut hierarchy = JTreeHierarchy::build(graph.clone(), config).unwrap();
let _ = hierarchy.approximate_min_cut();
let _ = hierarchy.min_cut(false);
let stats = hierarchy.statistics();
assert_eq!(stats.num_levels, hierarchy.num_levels());
assert!(stats.approx_queries > 0);
}
#[test]
fn test_config_validation() {
let graph = create_test_graph();
let config = JTreeConfig {
epsilon: 0.0,
..Default::default()
};
assert!(JTreeHierarchy::build(graph.clone(), config).is_err());
let config = JTreeConfig {
epsilon: 1.5,
..Default::default()
};
assert!(JTreeHierarchy::build(graph.clone(), config).is_err());
let config = JTreeConfig {
epsilon: 0.5,
..Default::default()
};
assert!(JTreeHierarchy::build(graph.clone(), config).is_ok());
}
#[test]
fn test_approximation_factor() {
let graph = create_test_graph();
let config = JTreeConfig {
epsilon: 0.5, ..Default::default()
};
let hierarchy = JTreeHierarchy::build(graph.clone(), config).unwrap();
let expected = hierarchy.alpha().powi(hierarchy.num_levels() as i32);
let actual = hierarchy.approximation_factor();
assert!((actual - expected).abs() < 1e-10);
}
#[test]
fn test_cut_result_helpers() {
let partition: HashSet<VertexId> = vec![1, 2, 3].into_iter().collect();
let exact = CutResult::exact(5.0, partition.clone());
assert!(exact.is_exact);
assert_eq!(exact.approximation_factor, 1.0);
assert_eq!(exact.tier_used, Tier::Exact);
let approx = CutResult::approximate(6.0, 2.0, partition.clone(), 1);
assert!(!approx.is_exact);
assert_eq!(approx.approximation_factor, 2.0);
assert_eq!(approx.tier_used, Tier::Approximate);
}
}