use std::collections::{HashMap, VecDeque};
use ahash::AHashSet;
use crate::errors::SqliteGraphError;
use crate::graph::SqliteGraph;
use crate::progress::ProgressCallback;
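// Minimum-cut and partitioning algorithms over a `SqliteGraph`:
// minimum s-t edge cuts via Edmonds-Karp max-flow, minimum vertex cuts via
// the node-splitting reduction, and BFS-level, greedy, and k-way partitioning.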
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MinCutResult {
pub source_side: AHashSet<i64>,
pub sink_side: AHashSet<i64>,
pub cut_edges: Vec<(i64, i64)>,
pub cut_size: usize,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MinVertexCutResult {
pub separator: AHashSet<i64>,
pub source_side: AHashSet<i64>,
pub sink_side: AHashSet<i64>,
pub cut_size: usize,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartitionResult {
pub partitions: Vec<AHashSet<i64>>,
pub cut_edges: Vec<(i64, i64)>,
pub node_to_partition: HashMap<i64, usize>,
}
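/// Tuning knobs for k-way partitioning: the number of parts `k`, a hard
/// per-partition size cap, a tolerated imbalance factor, and optional seed
/// nodes. A minimal sketch of overriding a single field (mirroring the
/// `..Default::default()` pattern used in the tests below):
///
/// ```ignore
/// let config = PartitionConfig { k: 4, ..Default::default() };
/// ```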
#[derive(Debug, Clone)]
pub struct PartitionConfig {
pub k: usize,
pub max_size: usize,
pub max_imbalance: f64,
pub seeds: Option<Vec<i64>>,
}
impl Default for PartitionConfig {
fn default() -> Self {
Self {
k: 2,
max_size: usize::MAX,
max_imbalance: 0.1,
seeds: None,
}
}
}
#[derive(Debug, Clone)]
struct FlowEdge {
    to: i64,
    /// Remaining capacity in the residual graph.
    capacity: usize,
    /// Capacity this arc was created with (0 for pure residual companions).
    original_capacity: usize,
    /// Index of the paired reverse arc in `adjacency[to]`.
    rev: usize,
}
impl FlowEdge {
    fn new(to: i64, capacity: usize, rev: usize) -> Self {
        Self {
            to,
            capacity,
            original_capacity: capacity,
            rev,
        }
    }
    fn residual(&self) -> usize {
        self.capacity
    }
}
struct FlowNetwork {
    adjacency: HashMap<i64, Vec<FlowEdge>>,
}
impl FlowNetwork {
    fn new() -> Self {
        Self {
            adjacency: HashMap::new(),
        }
    }
    fn add_edge(&mut self, from: i64, to: i64, capacity: usize) {
        if from == to {
            return;
        }
        // Record each arc's index in its endpoint's list so the paired
        // residual companion can be found in O(1) during augmentation.
        let forward_idx = self.adjacency.entry(from).or_default().len();
        let reverse_idx = self.adjacency.entry(to).or_default().len();
        self.adjacency
            .entry(from)
            .or_default()
            .push(FlowEdge::new(to, capacity, reverse_idx));
        self.adjacency
            .entry(to)
            .or_default()
            .push(FlowEdge::new(from, 0, forward_idx));
    }
fn neighbors(&self, node: i64) -> &[FlowEdge] {
self.adjacency
.get(&node)
.map(|v| v.as_slice())
.unwrap_or(&[])
}
fn nodes(&self) -> AHashSet<i64> {
self.adjacency.keys().copied().collect()
}
fn reachable_residual(&self, source: i64) -> AHashSet<i64> {
let mut visited = AHashSet::new();
let mut queue = VecDeque::new();
visited.insert(source);
queue.push_back(source);
while let Some(node) = queue.pop_front() {
for edge in self.neighbors(node) {
if edge.residual() > 0 && visited.insert(edge.to) {
queue.push_back(edge.to);
}
}
}
visited
}
    fn find_cut_edges(&self, source_side: &AHashSet<i64>) -> Vec<(i64, i64)> {
        let mut cut_edges = Vec::new();
        for &from in source_side {
            for edge in self.neighbors(from) {
                // Only saturated original arcs that cross the cut count;
                // residual companions (original_capacity == 0) are bookkeeping.
                if edge.original_capacity > 0
                    && !source_side.contains(&edge.to)
                    && edge.residual() == 0
                {
                    cut_edges.push((from, edge.to));
                }
            }
        }
        cut_edges
    }
}
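/// Edmonds-Karp max-flow: repeatedly augment along shortest (BFS) residual
/// paths until the sink is unreachable, then return the flow value together
/// with the final residual network, from which the minimum cut is read off.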
fn edmonds_karp(mut network: FlowNetwork, source: i64, sink: i64) -> (usize, FlowNetwork) {
let mut max_flow = 0;
while let Some(path) = bfs_augmenting_path(&network, source, sink) {
let bottleneck = find_bottleneck(&network, &path);
augment_flow(&mut network, &path, bottleneck);
max_flow += bottleneck;
}
(max_flow, network)
}
/// BFS for a shortest augmenting path in the residual graph, returned as
/// (from_node, edge_index) hops so later updates touch exactly the arcs the
/// search traversed, even when parallel arcs exist.
fn bfs_augmenting_path(
    network: &FlowNetwork,
    source: i64,
    sink: i64,
) -> Option<Vec<(i64, usize)>> {
    let mut parent: HashMap<i64, (i64, usize)> = HashMap::new();
    let mut queue = VecDeque::new();
    queue.push_back(source);
    parent.insert(source, (source, 0));
    while let Some(node) = queue.pop_front() {
        if node == sink {
            let mut path = Vec::new();
            let mut current = sink;
            while current != source {
                let (prev_node, edge_idx) = *parent.get(&current)?;
                path.push((prev_node, edge_idx));
                current = prev_node;
            }
            path.reverse();
            return Some(path);
        }
        for (edge_idx, edge) in network.neighbors(node).iter().enumerate() {
            if edge.residual() > 0 && !parent.contains_key(&edge.to) {
                parent.insert(edge.to, (node, edge_idx));
                queue.push_back(edge.to);
            }
        }
    }
    None
}
fn find_bottleneck(network: &FlowNetwork, path: &[(i64, usize)]) -> usize {
    let mut bottleneck = usize::MAX;
    for &(from, edge_idx) in path {
        if let Some(edge) = network.neighbors(from).get(edge_idx) {
            bottleneck = bottleneck.min(edge.residual());
        }
    }
    bottleneck
}
fn augment_flow(network: &mut FlowNetwork, path: &[(i64, usize)], amount: usize) {
    for &(from, edge_idx) in path {
        let mut companion = None;
        if let Some(edges) = network.adjacency.get_mut(&from) {
            if let Some(edge) = edges.get_mut(edge_idx) {
                edge.capacity = edge.capacity.saturating_sub(amount);
                companion = Some((edge.to, edge.rev));
            }
        }
        // Return the pushed amount to the paired reverse arc so that later
        // augmenting paths can cancel (reroute) this flow.
        if let Some((to, rev)) = companion {
            if let Some(edges) = network.adjacency.get_mut(&to) {
                if let Some(rev_edge) = edges.get_mut(rev) {
                    rev_edge.capacity += amount;
                }
            }
        }
    }
}
fn build_flow_network(graph: &SqliteGraph, source: i64, _sink: i64) -> FlowNetwork {
let mut network = FlowNetwork::new();
let mut nodes_to_visit = vec![source];
let mut visited = AHashSet::new();
visited.insert(source);
while let Some(node) = nodes_to_visit.pop() {
if let Ok(neighbors) = graph.fetch_outgoing(node) {
for &neighbor in &neighbors {
network.add_edge(node, neighbor, 1);
if visited.insert(neighbor) {
nodes_to_visit.push(neighbor);
}
}
}
}
network
}
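/// Compute a minimum s-t edge cut by running Edmonds-Karp max-flow over the
/// subgraph reachable from `source`, with unit capacity on every edge. The
/// source side is the set of nodes reachable in the final residual graph.
///
/// A minimal usage sketch, assuming a populated graph as in the tests below:
///
/// ```ignore
/// let result = min_st_cut(&graph, source_id, sink_id)?;
/// assert_eq!(result.cut_size, result.cut_edges.len());
/// ```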
pub fn min_st_cut(
graph: &SqliteGraph,
source: i64,
sink: i64,
) -> Result<MinCutResult, SqliteGraphError> {
if source == sink {
return Ok(MinCutResult {
source_side: {
let mut set = AHashSet::new();
set.insert(source);
set
},
sink_side: AHashSet::new(),
cut_edges: vec![],
cut_size: 0,
});
}
let network = build_flow_network(graph, source, sink);
if network.nodes().contains(&source) && !network.nodes().contains(&sink) {
return Ok(MinCutResult {
source_side: network.nodes(),
sink_side: AHashSet::new(),
cut_edges: vec![],
cut_size: 0,
});
}
let (max_flow, residual_network) = edmonds_karp(network, source, sink);
let source_side = residual_network.reachable_residual(source);
let all_nodes = residual_network.nodes();
let sink_side = all_nodes.difference(&source_side).copied().collect();
let cut_edges = residual_network.find_cut_edges(&source_side);
Ok(MinCutResult {
source_side,
sink_side,
cut_edges,
cut_size: max_flow,
})
}
pub fn min_st_cut_with_progress<F>(
graph: &SqliteGraph,
source: i64,
sink: i64,
progress: &F,
) -> Result<MinCutResult, SqliteGraphError>
where
F: ProgressCallback,
{
if source == sink {
return Ok(MinCutResult {
source_side: {
let mut set = AHashSet::new();
set.insert(source);
set
},
sink_side: AHashSet::new(),
cut_edges: vec![],
cut_size: 0,
});
}
let network = build_flow_network(graph, source, sink);
if network.nodes().contains(&source) && !network.nodes().contains(&sink) {
return Ok(MinCutResult {
source_side: network.nodes(),
sink_side: AHashSet::new(),
cut_edges: vec![],
cut_size: 0,
});
}
let mut current_network = network;
let mut max_flow = 0;
let mut iteration = 0;
    while let Some(path) = bfs_augmenting_path(&current_network, source, sink) {
        iteration += 1;
        let bottleneck = find_bottleneck(&current_network, &path);
augment_flow(&mut current_network, &path, bottleneck);
max_flow += bottleneck;
progress.on_progress(
iteration,
None,
&format!(
"Min cut: iteration {}, flow so far: {}",
iteration, max_flow
),
);
}
progress.on_complete();
let source_side = current_network.reachable_residual(source);
let all_nodes = current_network.nodes();
let sink_side = all_nodes.difference(&source_side).copied().collect();
let cut_edges = current_network.find_cut_edges(&source_side);
Ok(MinCutResult {
source_side,
sink_side,
cut_edges,
cut_size: max_flow,
})
}
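/// Reduces vertex cuts to edge cuts: every interior node `x` is split into
/// `x_in = 2x` and `x_out = 2x + 1` joined by a unit-capacity internal edge,
/// while `source` and `sink` keep their raw ids (so, for example, interior
/// node 5 becomes the pair (10, 11)). Note this id scheme assumes the doubled
/// ids do not collide with the raw source and sink ids.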
struct VertexSplitTransform {
source: i64,
sink: i64,
}
impl VertexSplitTransform {
fn new(source: i64, sink: i64) -> Self {
Self { source, sink }
}
    fn node_in(&self, x: i64) -> i64 {
        if x == self.source || x == self.sink {
            x
        } else {
            x * 2
        }
    }
    fn node_out(&self, x: i64) -> i64 {
        if x == self.source || x == self.sink {
            x
        } else {
            x * 2 + 1
        }
    }
    fn to_original(&self, node_id: i64) -> i64 {
        if node_id == self.source || node_id == self.sink {
            node_id
        } else if node_id % 2 == 0 {
            node_id / 2
        } else {
            (node_id - 1) / 2
        }
    }
fn is_internal_edge(&self, from: i64, to: i64) -> Option<i64> {
if from % 2 == 0 && to == from + 1 {
let original = from / 2;
if original != self.source && original != self.sink {
return Some(original);
}
}
if from % 2 == 1 && to == from - 1 {
let original = (from - 1) / 2;
if original != self.source && original != self.sink {
return Some(original);
}
}
None
}
}
fn build_vertex_split_network(
    graph: &SqliteGraph,
    source: i64,
    sink: i64,
) -> (FlowNetwork, VertexSplitTransform) {
    let transform = VertexSplitTransform::new(source, sink);
    let mut network = FlowNetwork::new();
    let mut nodes_to_visit = vec![source];
    let mut visited = AHashSet::new();
    visited.insert(source);
    while let Some(node) = nodes_to_visit.pop() {
        // Each node is popped exactly once, so every interior node gets its
        // unit-capacity internal edge added exactly once.
        if node != source && node != sink {
            network.add_edge(transform.node_in(node), transform.node_out(node), 1);
        }
        if let Ok(neighbors) = graph.fetch_outgoing(node) {
            for &neighbor in &neighbors {
                network.add_edge(transform.node_out(node), transform.node_in(neighbor), 1);
                if visited.insert(neighbor) {
                    nodes_to_visit.push(neighbor);
                }
            }
        }
    }
    (network, transform)
}
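/// Compute a minimum vertex separator between `source` and `target` via
/// max-flow on the vertex-split network; the separator is read off the
/// internal edges of the final residual network.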
pub fn min_vertex_cut(
graph: &SqliteGraph,
source: i64,
target: i64,
) -> Result<MinVertexCutResult, SqliteGraphError> {
if source == target {
return Ok(MinVertexCutResult {
separator: AHashSet::new(),
source_side: {
let mut set = AHashSet::new();
set.insert(source);
set
},
sink_side: AHashSet::new(),
cut_size: 0,
});
}
let (network, transform) = build_vertex_split_network(graph, source, target);
let source_out = transform.node_out(source);
let target_in = transform.node_in(target);
if !network.nodes().contains(&target_in) {
return Ok(MinVertexCutResult {
separator: AHashSet::new(),
source_side: {
let mut set = AHashSet::new();
set.insert(source);
set
},
sink_side: AHashSet::new(),
cut_size: 0,
});
}
let (_max_flow, residual_network) = edmonds_karp(network, source_out, target_in);
let mut separator = AHashSet::new();
for node in residual_network.nodes() {
for edge in residual_network.neighbors(node) {
if let Some(original) = transform.is_internal_edge(node, edge.to) {
if edge.residual() == 0 {
separator.insert(original);
}
}
}
}
let source_side_transformed = residual_network.reachable_residual(source_out);
let mut source_side = AHashSet::new();
for node in source_side_transformed {
source_side.insert(transform.to_original(node));
}
let all_nodes_transformed = residual_network.nodes();
let mut sink_side = AHashSet::new();
for node in all_nodes_transformed {
let original = transform.to_original(node);
if !source_side.contains(&original) {
sink_side.insert(original);
}
}
Ok(MinVertexCutResult {
separator: separator.clone(),
source_side,
sink_side,
cut_size: separator.len(),
})
}
pub fn min_vertex_cut_with_progress<F>(
graph: &SqliteGraph,
source: i64,
target: i64,
progress: &F,
) -> Result<MinVertexCutResult, SqliteGraphError>
where
F: ProgressCallback,
{
if source == target {
return Ok(MinVertexCutResult {
separator: AHashSet::new(),
source_side: {
let mut set = AHashSet::new();
set.insert(source);
set
},
sink_side: AHashSet::new(),
cut_size: 0,
});
}
let (network, transform) = build_vertex_split_network(graph, source, target);
let source_out = transform.node_out(source);
let target_in = transform.node_in(target);
if !network.nodes().contains(&target_in) {
return Ok(MinVertexCutResult {
separator: AHashSet::new(),
source_side: {
let mut set = AHashSet::new();
set.insert(source);
set
},
sink_side: AHashSet::new(),
cut_size: 0,
});
}
let mut current_network = network;
let mut max_flow = 0;
let mut iteration = 0;
    while let Some(path) = bfs_augmenting_path(&current_network, source_out, target_in) {
        iteration += 1;
        let bottleneck = find_bottleneck(&current_network, &path);
augment_flow(&mut current_network, &path, bottleneck);
max_flow += bottleneck;
progress.on_progress(
iteration,
None,
&format!(
"Vertex cut: iteration {}, flow so far: {}",
iteration, max_flow
),
);
}
progress.on_complete();
let mut separator = AHashSet::new();
for node in current_network.nodes() {
for edge in current_network.neighbors(node) {
if let Some(original) = transform.is_internal_edge(node, edge.to) {
if edge.residual() == 0 {
separator.insert(original);
}
}
}
}
let source_side_transformed = current_network.reachable_residual(source_out);
let mut source_side = AHashSet::new();
for node in source_side_transformed {
source_side.insert(transform.to_original(node));
}
let all_nodes_transformed = current_network.nodes();
let mut sink_side = AHashSet::new();
for node in all_nodes_transformed {
let original = transform.to_original(node);
if !source_side.contains(&original) {
sink_side.insert(original);
}
}
    // Report the separator size so the result agrees with `min_vertex_cut`.
    Ok(MinVertexCutResult {
        separator: separator.clone(),
        source_side,
        sink_side,
        cut_size: separator.len(),
    })
}
fn compute_cut_edges(
graph: &SqliteGraph,
node_to_partition: &HashMap<i64, usize>,
) -> Vec<(i64, i64)> {
let mut cut_edges = Vec::new();
let nodes_to_check: Vec<i64> = if let Ok(all_ids) = graph.all_entity_ids() {
all_ids
} else {
return cut_edges;
};
for &from_node in &nodes_to_check {
if let Ok(neighbors) = graph.fetch_outgoing(from_node) {
for &to_node in &neighbors {
if let (Some(&from_partition), Some(&to_partition)) = (
node_to_partition.get(&from_node),
node_to_partition.get(&to_node),
) {
if from_partition != to_partition {
cut_edges.push((from_node, to_node));
}
}
}
}
}
cut_edges
}
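/// Partition by multi-source BFS: each seed claims unvisited out-neighbors in
/// breadth-first order until the reachable graph is exhausted; empty `seeds`
/// fall back to the `k` smallest node ids.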
pub fn partition_bfs_level(
graph: &SqliteGraph,
seeds: Vec<i64>,
k: usize,
) -> Result<PartitionResult, SqliteGraphError> {
if k < 2 {
return Ok(PartitionResult {
partitions: vec![AHashSet::new()],
cut_edges: vec![],
node_to_partition: HashMap::new(),
});
}
let all_nodes: AHashSet<i64> = graph.all_entity_ids()?.into_iter().collect();
if all_nodes.is_empty() {
return Ok(PartitionResult {
partitions: vec![AHashSet::new(); k],
cut_edges: vec![],
node_to_partition: HashMap::new(),
});
}
let mut effective_seeds = seeds;
if effective_seeds.is_empty() {
let mut sorted_nodes: Vec<i64> = all_nodes.iter().copied().collect();
sorted_nodes.sort();
effective_seeds = sorted_nodes.into_iter().take(k).collect();
}
    effective_seeds.truncate(k);
let num_partitions = k.max(effective_seeds.len());
let mut partitions: Vec<AHashSet<i64>> = (0..num_partitions).map(|_| AHashSet::new()).collect();
let mut node_to_partition: HashMap<i64, usize> = HashMap::new();
let mut queue: VecDeque<(i64, usize, usize)> = VecDeque::new();
let mut visited: AHashSet<i64> = AHashSet::new();
for (seed_idx, &seed) in effective_seeds.iter().enumerate() {
if all_nodes.contains(&seed) {
partitions[seed_idx].insert(seed);
node_to_partition.insert(seed, seed_idx);
visited.insert(seed);
queue.push_back((seed, 0, seed_idx));
}
}
    while let Some((node, level, seed_idx)) = queue.pop_front() {
        if let Ok(neighbors) = graph.fetch_outgoing(node) {
            for &neighbor in &neighbors {
                if visited.insert(neighbor) {
                    partitions[seed_idx].insert(neighbor);
                    node_to_partition.insert(neighbor, seed_idx);
                    queue.push_back((neighbor, level + 1, seed_idx));
                }
            }
        }
    }
while partitions.len() < k {
partitions.push(AHashSet::new());
}
let cut_edges = compute_cut_edges(graph, &node_to_partition);
Ok(PartitionResult {
partitions,
cut_edges,
node_to_partition,
})
}
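/// Two-way refinement in the spirit of Kernighan-Lin: repeatedly move the
/// node with the best gain (outgoing edges to the other side minus edges kept
/// within its own side) and keep the best cut seen across moves.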
pub fn partition_greedy(
graph: &SqliteGraph,
initial_partition: Option<Vec<AHashSet<i64>>>,
max_iterations: usize,
) -> Result<PartitionResult, SqliteGraphError> {
let all_nodes: AHashSet<i64> = graph.all_entity_ids()?.into_iter().collect();
if all_nodes.is_empty() {
return Ok(PartitionResult {
partitions: vec![AHashSet::new(), AHashSet::new()],
cut_edges: vec![],
node_to_partition: HashMap::new(),
});
}
let (mut partitions, mut node_to_partition) = if let Some(init) = initial_partition {
        // The gain computation below assumes exactly two sides, so fall back
        // to a fresh bisection for any other partition count.
        if init.len() != 2 {
let init_result = partition_bfs_level(graph, vec![], 2)?;
(init_result.partitions, init_result.node_to_partition)
} else {
let mut mapping = HashMap::new();
for (pidx, partition) in init.iter().enumerate() {
for &node in partition {
mapping.insert(node, pidx);
}
}
(init, mapping)
}
} else {
let init_result = partition_bfs_level(graph, vec![], 2)?;
(init_result.partitions, init_result.node_to_partition)
};
if partitions.len() != 2 {
partitions.resize(2, AHashSet::new());
}
let initial_cut_size = compute_cut_edges(graph, &node_to_partition).len();
let mut best_partitions = partitions.clone();
let mut best_mapping = node_to_partition.clone();
let mut best_cut_size = initial_cut_size;
for _iteration in 0..max_iterations {
let mut improvement_found = false;
        let mut best_move: Option<(i64, usize, i64)> = None;
        let mut best_gain: i64 = 0;
for &node in all_nodes.iter() {
if let Some(&from_partition) = node_to_partition.get(&node) {
let to_partition = 1 - from_partition;
let mut edges_to_other = 0i64;
let mut edges_within = 0i64;
if let Ok(neighbors) = graph.fetch_outgoing(node) {
for &neighbor in &neighbors {
if let Some(&neighbor_partition) = node_to_partition.get(&neighbor) {
if neighbor_partition == to_partition {
edges_to_other += 1;
} else if neighbor_partition == from_partition && neighbor != node {
edges_within += 1;
}
}
}
}
let gain = edges_to_other - edges_within;
if gain > best_gain {
best_gain = gain;
best_move = Some((node, from_partition, gain));
improvement_found = true;
}
}
}
        if !improvement_found || best_gain <= 0 {
            break;
        }
if let Some((node, from_partition, _gain)) = best_move {
let to_partition = 1 - from_partition;
partitions[from_partition].remove(&node);
partitions[to_partition].insert(node);
node_to_partition.insert(node, to_partition);
let current_cut_size = compute_cut_edges(graph, &node_to_partition).len();
if current_cut_size < best_cut_size {
best_cut_size = current_cut_size;
best_partitions = partitions.clone();
best_mapping = node_to_partition.clone();
}
}
}
let cut_edges = compute_cut_edges(graph, &best_mapping);
Ok(PartitionResult {
partitions: best_partitions,
cut_edges,
node_to_partition: best_mapping,
})
}
fn select_seeds_by_degree(
graph: &SqliteGraph,
k: usize,
available_nodes: &AHashSet<i64>,
) -> Vec<i64> {
let mut node_degrees: Vec<(i64, usize)> = Vec::new();
for &node in available_nodes {
if let Ok(outgoing) = graph.fetch_outgoing(node) {
let degree = outgoing.len();
node_degrees.push((node, degree));
}
}
    node_degrees.sort_by_key(|&(_, degree)| std::cmp::Reverse(degree));
node_degrees.truncate(k);
node_degrees.into_iter().map(|(node, _)| node).collect()
}
fn shortest_distance_to_targets(graph: &SqliteGraph, from: i64, targets: &AHashSet<i64>) -> usize {
if targets.contains(&from) {
return 0;
}
let mut visited: AHashSet<i64> = AHashSet::new();
let mut queue: VecDeque<(i64, usize)> = VecDeque::new();
visited.insert(from);
queue.push_back((from, 0));
while let Some((node, dist)) = queue.pop_front() {
if let Ok(neighbors) = graph.fetch_outgoing(node) {
for &neighbor in &neighbors {
if targets.contains(&neighbor) {
return dist + 1;
}
if visited.insert(neighbor) {
queue.push_back((neighbor, dist + 1));
}
}
}
}
    usize::MAX
}
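/// K-way partitioning: grow partitions by BFS from degree-based (or provided)
/// seeds under a size cap, then attach any leftover nodes to the partition
/// reachable by the shortest outgoing path.
///
/// A minimal usage sketch:
///
/// ```ignore
/// let config = PartitionConfig { k: 3, ..Default::default() };
/// let result = partition_kway(&graph, &config)?;
/// ```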
pub fn partition_kway(
graph: &SqliteGraph,
config: &PartitionConfig,
) -> Result<PartitionResult, SqliteGraphError> {
if config.k < 2 {
return Err(SqliteGraphError::InvalidInput(
"k must be at least 2 for partitioning".to_string(),
));
}
let all_nodes: AHashSet<i64> = graph.all_entity_ids()?.into_iter().collect();
if all_nodes.is_empty() {
return Ok(PartitionResult {
partitions: vec![AHashSet::new(); config.k],
cut_edges: vec![],
node_to_partition: HashMap::new(),
});
}
let effective_k = config.k.min(all_nodes.len());
let mut partitions: Vec<AHashSet<i64>> = (0..effective_k).map(|_| AHashSet::new()).collect();
let mut node_to_partition: HashMap<i64, usize> = HashMap::new();
let seeds = if let Some(ref provided_seeds) = config.seeds {
provided_seeds.clone()
} else {
select_seeds_by_degree(graph, effective_k, &all_nodes)
};
let mut effective_seeds = seeds;
effective_seeds.truncate(effective_k);
while effective_seeds.len() < effective_k {
for &node in &all_nodes {
if !effective_seeds.contains(&node) {
effective_seeds.push(node);
if effective_seeds.len() >= effective_k {
break;
}
}
}
}
let target_size = (all_nodes.len() / effective_k).max(1);
let max_allowed = if config.max_size == usize::MAX {
((target_size as f64) * (1.0 + config.max_imbalance)) as usize
} else {
config.max_size.min(all_nodes.len())
};
    let mut queue: VecDeque<(i64, usize)> = VecDeque::new();
    let mut unassigned: AHashSet<i64> = AHashSet::new();
for (pidx, &seed) in effective_seeds.iter().enumerate() {
if all_nodes.contains(&seed) {
partitions[pidx].insert(seed);
node_to_partition.insert(seed, pidx);
queue.push_back((seed, pidx));
}
}
for &node in &all_nodes {
if !node_to_partition.contains_key(&node) {
unassigned.insert(node);
}
}
while let Some((node, pidx)) = queue.pop_front() {
if partitions[pidx].len() >= max_allowed {
continue;
}
if let Ok(neighbors) = graph.fetch_outgoing(node) {
for &neighbor in &neighbors {
if unassigned.remove(&neighbor) {
partitions[pidx].insert(neighbor);
node_to_partition.insert(neighbor, pidx);
queue.push_back((neighbor, pidx));
}
}
}
}
for &node in &unassigned {
let mut best_partition = 0;
let mut best_distance = usize::MAX;
        for (pidx, partition) in partitions.iter().enumerate() {
            if partition.is_empty() {
                continue;
            }
            let distance = shortest_distance_to_targets(graph, node, partition);
if distance < best_distance {
best_distance = distance;
best_partition = pidx;
}
}
partitions[best_partition].insert(node);
node_to_partition.insert(node, best_partition);
}
while partitions.len() < config.k {
partitions.push(AHashSet::new());
}
let cut_edges = compute_cut_edges(graph, &node_to_partition);
Ok(PartitionResult {
partitions,
cut_edges,
node_to_partition,
})
}
pub fn partition_kway_with_progress<F>(
graph: &SqliteGraph,
config: &PartitionConfig,
progress: &F,
) -> Result<PartitionResult, SqliteGraphError>
where
F: ProgressCallback,
{
if config.k < 2 {
return Err(SqliteGraphError::InvalidInput(
"k must be at least 2 for partitioning".to_string(),
));
}
let all_nodes: AHashSet<i64> = graph.all_entity_ids()?.into_iter().collect();
let total_nodes = all_nodes.len();
if all_nodes.is_empty() {
progress.on_complete();
return Ok(PartitionResult {
partitions: vec![AHashSet::new(); config.k],
cut_edges: vec![],
node_to_partition: HashMap::new(),
});
}
let effective_k = config.k.min(all_nodes.len());
let mut partitions: Vec<AHashSet<i64>> = (0..effective_k).map(|_| AHashSet::new()).collect();
let mut node_to_partition: HashMap<i64, usize> = HashMap::new();
let seeds = if let Some(ref provided_seeds) = config.seeds {
provided_seeds.clone()
} else {
select_seeds_by_degree(graph, effective_k, &all_nodes)
};
let mut effective_seeds = seeds;
effective_seeds.truncate(effective_k);
while effective_seeds.len() < effective_k {
for &node in &all_nodes {
if !effective_seeds.contains(&node) {
effective_seeds.push(node);
if effective_seeds.len() >= effective_k {
break;
}
}
}
}
let target_size = (all_nodes.len() / effective_k).max(1);
let max_allowed = if config.max_size == usize::MAX {
((target_size as f64) * (1.0 + config.max_imbalance)) as usize
} else {
config.max_size.min(all_nodes.len())
};
let mut queue: VecDeque<(i64, usize)> = VecDeque::new();
let mut unassigned: AHashSet<i64> = AHashSet::new();
let mut assigned_count = 0;
for (pidx, &seed) in effective_seeds.iter().enumerate() {
if all_nodes.contains(&seed) {
partitions[pidx].insert(seed);
node_to_partition.insert(seed, pidx);
assigned_count += 1;
queue.push_back((seed, pidx));
}
}
for &node in &all_nodes {
if !node_to_partition.contains_key(&node) {
unassigned.insert(node);
}
}
while let Some((node, pidx)) = queue.pop_front() {
if partitions[pidx].len() >= max_allowed {
continue;
}
if let Ok(neighbors) = graph.fetch_outgoing(node) {
for &neighbor in &neighbors {
if unassigned.remove(&neighbor) {
partitions[pidx].insert(neighbor);
node_to_partition.insert(neighbor, pidx);
assigned_count += 1;
queue.push_back((neighbor, pidx));
if assigned_count % 10 == 0 {
progress.on_progress(
assigned_count,
Some(total_nodes),
&format!(
"K-way partition: assigned {}/{} nodes",
assigned_count, total_nodes
),
);
}
}
}
}
}
for &node in &unassigned {
let mut best_partition = 0;
let mut best_distance = usize::MAX;
        for (pidx, partition) in partitions.iter().enumerate() {
            if partition.is_empty() {
                continue;
            }
            let distance = shortest_distance_to_targets(graph, node, partition);
if distance < best_distance {
best_distance = distance;
best_partition = pidx;
}
}
partitions[best_partition].insert(node);
node_to_partition.insert(node, best_partition);
assigned_count += 1;
}
    let _ = assigned_count;
    progress.on_complete();
while partitions.len() < config.k {
partitions.push(AHashSet::new());
}
let cut_edges = compute_cut_edges(graph, &node_to_partition);
Ok(PartitionResult {
partitions,
cut_edges,
node_to_partition,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{GraphEdge, GraphEntity};
fn create_linear_chain() -> (SqliteGraph, i64, i64) {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..4 {
let entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({"index": i}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
for i in 0..entity_ids.len().saturating_sub(1) {
let edge = GraphEdge {
id: 0,
from_id: entity_ids[i],
to_id: entity_ids[i + 1],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
}
(graph, entity_ids[0], entity_ids[3])
}
fn create_diamond() -> (SqliteGraph, i64, i64) {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..4 {
let entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({"index": i}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let edges = vec![(0, 1), (0, 2), (1, 3), (2, 3)];
for (from_idx, to_idx) in edges {
let edge = GraphEdge {
id: 0,
from_id: entity_ids[from_idx],
to_id: entity_ids[to_idx],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
}
(graph, entity_ids[0], entity_ids[3])
}
fn create_parallel_paths() -> (SqliteGraph, i64, i64) {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..5 {
let entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({"index": i}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let edges = vec![(0, 1), (1, 4), (0, 2), (2, 4), (0, 3), (3, 4)];
for (from_idx, to_idx) in edges {
let edge = GraphEdge {
id: 0,
from_id: entity_ids[from_idx],
to_id: entity_ids[to_idx],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
}
(graph, entity_ids[0], entity_ids[4])
}
fn create_single_edge() -> (SqliteGraph, i64, i64) {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..2 {
let entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({"index": i}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let edge = GraphEdge {
id: 0,
from_id: entity_ids[0],
to_id: entity_ids[1],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
(graph, entity_ids[0], entity_ids[1])
}
fn create_disconnected() -> (SqliteGraph, i64, i64) {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..4 {
let entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({"index": i}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let edge1 = GraphEdge {
id: 0,
from_id: entity_ids[0],
to_id: entity_ids[1],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge1).expect("Failed to insert edge");
let edge2 = GraphEdge {
id: 0,
from_id: entity_ids[2],
to_id: entity_ids[3],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge2).expect("Failed to insert edge");
(graph, entity_ids[0], entity_ids[3])
}
#[test]
fn test_min_st_cut_linear_chain() {
let (graph, source, sink) = create_linear_chain();
let result = min_st_cut(&graph, source, sink).expect("Failed to compute min cut");
assert_eq!(result.cut_size, 1, "Linear chain should have cut size 1");
assert_eq!(result.cut_edges.len(), 1, "Should have 1 cut edge");
assert!(
result.source_side.contains(&source),
"Source side should contain source"
);
assert!(
result.sink_side.contains(&sink),
"Sink side should contain sink"
);
}
#[test]
fn test_min_st_cut_diamond() {
let (graph, source, sink) = create_diamond();
let result = min_st_cut(&graph, source, sink).expect("Failed to compute min cut");
assert_eq!(result.cut_size, 2, "Diamond should have cut size 2");
assert_eq!(result.cut_edges.len(), 2, "Should have 2 cut edges");
}
#[test]
fn test_min_st_cut_parallel_paths() {
let (graph, source, sink) = create_parallel_paths();
let result = min_st_cut(&graph, source, sink).expect("Failed to compute min cut");
assert_eq!(result.cut_size, 3, "Parallel paths should have cut size 3");
assert_eq!(result.cut_edges.len(), 3, "Should have 3 cut edges");
}
#[test]
fn test_min_st_cut_single_edge() {
let (graph, source, sink) = create_single_edge();
let result = min_st_cut(&graph, source, sink).expect("Failed to compute min cut");
assert_eq!(result.cut_size, 1, "Single edge should have cut size 1");
assert_eq!(result.cut_edges.len(), 1, "Should have 1 cut edge");
assert_eq!(
result.cut_edges[0],
(source, sink),
"Cut edge should be (source, sink)"
);
}
#[test]
fn test_min_st_cut_source_equals_target() {
let (graph, source, _) = create_single_edge();
let result = min_st_cut(&graph, source, source).expect("Failed to compute min cut");
assert_eq!(result.cut_size, 0, "Source==target should have cut size 0");
assert!(result.cut_edges.is_empty(), "Cut edges should be empty");
assert!(
result.source_side.contains(&source),
"Source side contains source"
);
assert!(result.sink_side.is_empty(), "Sink side should be empty");
}
#[test]
fn test_min_st_cut_with_progress_matches() {
use crate::progress::NoProgress;
let (graph, source, sink) = create_diamond();
let progress = NoProgress;
let result_with =
min_st_cut_with_progress(&graph, source, sink, &progress).expect("Failed");
let result_without = min_st_cut(&graph, source, sink).expect("Failed");
assert_eq!(
result_with.cut_size, result_without.cut_size,
"Cut size should match"
);
assert_eq!(
result_with.cut_edges.len(),
result_without.cut_edges.len(),
"Cut edges count should match"
);
}
#[test]
fn test_min_vertex_cut_bridge_node() {
let (graph, source, sink) = create_linear_chain();
let result = min_vertex_cut(&graph, source, sink).expect("Failed to compute vertex cut");
assert_eq!(
result.cut_size, 2,
"Linear chain should have vertex cut size 2 (both intermediate nodes)"
);
assert_eq!(
result.separator.len(),
2,
"Should have 2 separator vertices"
);
}
#[test]
fn test_min_vertex_cut_two_parallel_paths() {
let (graph, source, sink) = create_diamond();
let result = min_vertex_cut(&graph, source, sink).expect("Failed to compute vertex cut");
assert_eq!(
result.cut_size, 2,
"Two parallel paths should have vertex cut size 2"
);
assert_eq!(
result.separator.len(),
2,
"Should have 2 separator vertices"
);
}
#[test]
fn test_min_vertex_cut_direct_edge() {
let (graph, source, sink) = create_single_edge();
eprintln!("Direct edge test: source={}, sink={}", source, sink);
let result = min_vertex_cut(&graph, source, sink).expect("Failed to compute vertex cut");
eprintln!(
"cut_size={}, separator={:?}",
result.cut_size, result.separator
);
assert_eq!(
result.cut_size, 0,
"Direct edge should have vertex cut size 0"
);
assert!(
result.separator.is_empty(),
"Separator should be empty for direct edge"
);
}
#[test]
fn test_min_vertex_cut_source_equals_target() {
let (graph, source, _) = create_single_edge();
let result = min_vertex_cut(&graph, source, source).expect("Failed to compute vertex cut");
assert_eq!(result.cut_size, 0, "Source==target should have cut size 0");
assert!(result.separator.is_empty(), "Separator should be empty");
assert!(
result.source_side.contains(&source),
"Source side contains source"
);
}
#[test]
fn test_min_vertex_cut_with_progress_matches() {
use crate::progress::NoProgress;
let (graph, source, sink) = create_diamond();
let progress = NoProgress;
let result_with =
min_vertex_cut_with_progress(&graph, source, sink, &progress).expect("Failed");
let result_without = min_vertex_cut(&graph, source, sink).expect("Failed");
assert_eq!(
result_with.cut_size, result_without.cut_size,
"Cut size should match"
);
assert_eq!(
result_with.separator.len(),
result_without.separator.len(),
"Separator size should match"
);
}
#[test]
fn test_min_vertex_cut_three_parallel_paths() {
let (graph, source, sink) = create_parallel_paths();
let result = min_vertex_cut(&graph, source, sink).expect("Failed to compute vertex cut");
assert_eq!(
result.cut_size, 3,
"Three parallel paths should have vertex cut size 3"
);
assert_eq!(
result.separator.len(),
3,
"Should have 3 separator vertices"
);
}
fn create_path_graph() -> SqliteGraph {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..5 {
let entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({"index": i}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
for i in 0..entity_ids.len().saturating_sub(1) {
let edge = GraphEdge {
id: 0,
from_id: entity_ids[i],
to_id: entity_ids[i + 1],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
}
graph
}
fn create_star_graph(leaves: usize) -> SqliteGraph {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
let center_entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: "center".to_string(),
file_path: Some("center.rs".to_string()),
data: serde_json::json!({}),
};
graph
        .insert_entity(&center_entity)
.expect("Failed to insert entity");
for i in 0..leaves {
let leaf_entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("leaf_{}", i),
file_path: Some(format!("leaf_{}.rs", i)),
data: serde_json::json!({}),
};
graph
.insert_entity(&leaf_entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let center_id = entity_ids[0];
for i in 1..entity_ids.len() {
let edge = GraphEdge {
id: 0,
from_id: center_id,
to_id: entity_ids[i],
edge_type: "edge".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
}
graph
}
fn create_binary_tree(height: usize) -> SqliteGraph {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
let num_nodes = 2_usize.pow(height as u32 + 1) - 1;
for i in 0..num_nodes {
let entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({"index": i}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
for i in 0..num_nodes / 2 {
let left_child = 2 * i + 1;
let right_child = 2 * i + 2;
if left_child < num_nodes {
let edge = GraphEdge {
id: 0,
from_id: entity_ids[i],
to_id: entity_ids[left_child],
edge_type: "left".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
}
if right_child < num_nodes {
let edge = GraphEdge {
id: 0,
from_id: entity_ids[i],
to_id: entity_ids[right_child],
edge_type: "right".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
}
}
graph
}
fn create_two_cliques() -> SqliteGraph {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..3 {
let entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("c1_{}", i),
file_path: Some(format!("c1_{}.rs", i)),
data: serde_json::json!({"clique": 1}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
for i in 3..6 {
let entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("c2_{}", i),
file_path: Some(format!("c2_{}.rs", i)),
data: serde_json::json!({"clique": 2}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
for i in 0..3 {
for j in (i + 1)..3 {
let edge = GraphEdge {
id: 0,
from_id: entity_ids[i],
to_id: entity_ids[j],
edge_type: "intra".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
}
}
for i in 3..6 {
for j in (i + 1)..6 {
let edge = GraphEdge {
id: 0,
from_id: entity_ids[i],
to_id: entity_ids[j],
edge_type: "intra".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
}
}
let bridge = GraphEdge {
id: 0,
from_id: entity_ids[1],
to_id: entity_ids[4],
edge_type: "bridge".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&bridge).expect("Failed to insert edge");
graph
}
#[test]
fn test_partition_bfs_level_path_graph() {
let graph = create_path_graph();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let result = partition_bfs_level(&graph, vec![entity_ids[0], entity_ids[4]], 2)
.expect("Failed to partition");
assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
assert_eq!(
result.partitions[0].len() + result.partitions[1].len(),
5,
"All nodes should be assigned"
);
assert!(result.cut_edges.len() <= 2, "Cut edges should be minimal");
}
#[test]
fn test_partition_bfs_level_star_graph() {
let graph = create_star_graph(4);
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let result = partition_bfs_level(&graph, vec![entity_ids[0], entity_ids[2]], 2)
.expect("Failed to partition");
assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
assert_eq!(
result.partitions[0].len() + result.partitions[1].len(),
5,
"All nodes should be assigned"
);
}
#[test]
fn test_partition_bfs_level_binary_tree() {
let graph = create_binary_tree(2);
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let result = partition_bfs_level(&graph, vec![entity_ids[0], entity_ids[6]], 2)
.expect("Failed to partition");
assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
assert_eq!(
result.partitions[0].len() + result.partitions[1].len(),
7,
"All nodes should be assigned"
);
}
#[test]
fn test_partition_bfs_level_disconnected() {
let (graph, node_a, node_b) = create_disconnected();
let result =
partition_bfs_level(&graph, vec![node_a, node_b], 2).expect("Failed to partition");
assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
assert!(
            result.partitions.iter().all(|p| !p.is_empty()),
"Each partition should have at least one node"
);
}
#[test]
fn test_partition_bfs_level_empty_seeds() {
let graph = create_path_graph();
let result = partition_bfs_level(&graph, vec![], 2).expect("Failed to partition");
assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
}
#[test]
fn test_partition_greedy_two_cliques() {
let graph = create_two_cliques();
let result = partition_greedy(&graph, None, 100).expect("Failed to partition");
assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
let total_assigned = result.partitions[0].len() + result.partitions[1].len();
assert!(
total_assigned >= 3,
"Should assign at least some nodes, got {}",
total_assigned
);
}
#[test]
fn test_partition_greedy_cut_size_decreases() {
let graph = create_binary_tree(2);
let initial = partition_bfs_level(&graph, vec![], 2).expect("Failed");
let initial_cut_size = initial.cut_edges.len();
let result = partition_greedy(&graph, None, 100).expect("Failed to partition");
assert!(
result.cut_edges.len() <= initial_cut_size,
"Greedy should not increase cut size"
);
}
#[test]
fn test_partition_greedy_with_initial_partition() {
let graph = create_path_graph();
let initial_partition = vec![
graph
.all_entity_ids()
.unwrap()
.into_iter()
.take(2)
.collect(),
graph
.all_entity_ids()
.unwrap()
.into_iter()
.skip(2)
.collect(),
];
let result =
partition_greedy(&graph, Some(initial_partition), 10).expect("Failed to partition");
assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
}
#[test]
fn test_partition_kway_balanced() {
let _graph = create_path_graph();
let large_graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..10 {
let entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({"index": i}),
};
large_graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let entity_ids: Vec<i64> = large_graph.list_entity_ids().expect("Failed to get IDs");
for i in 0..entity_ids.len().saturating_sub(1) {
let edge = GraphEdge {
id: 0,
from_id: entity_ids[i],
to_id: entity_ids[i + 1],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
large_graph
.insert_edge(&edge)
.expect("Failed to insert edge");
}
let config = PartitionConfig {
k: 2,
max_size: 5,
max_imbalance: 0.1,
seeds: None,
};
let result = partition_kway(&large_graph, &config).expect("Failed to partition");
assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
let total: usize = result.partitions.iter().map(|p| p.len()).sum();
assert_eq!(total, 10, "All 10 nodes should be assigned");
}
#[test]
fn test_partition_kway_three_way() {
let graph = create_path_graph();
let config = PartitionConfig {
k: 3,
max_size: 4,
            max_imbalance: 0.5,
            seeds: None,
};
let result = partition_kway(&graph, &config).expect("Failed to partition");
assert_eq!(result.partitions.len(), 3, "Should have 3 partitions");
let total_assigned: usize = result.partitions.iter().map(|p| p.len()).sum();
assert_eq!(total_assigned, 5, "All 5 nodes should be assigned");
}
#[test]
fn test_partition_kway_with_isolated_node() {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..3 {
let entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let isolated = GraphEntity {
id: 0,
kind: "node".to_string(),
name: "isolated".to_string(),
file_path: Some("isolated.rs".to_string()),
data: serde_json::json!({}),
};
graph
.insert_entity(&isolated)
.expect("Failed to insert entity");
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
for i in 0..2 {
let edge = GraphEdge {
id: 0,
from_id: entity_ids[i],
to_id: entity_ids[i + 1],
edge_type: "next".to_string(),
data: serde_json::json!({}),
};
graph.insert_edge(&edge).expect("Failed to insert edge");
}
let config = PartitionConfig {
k: 2,
max_size: usize::MAX,
max_imbalance: 0.1,
seeds: None,
};
let result = partition_kway(&graph, &config).expect("Failed to partition");
let total_assigned: usize = result.partitions.iter().map(|p| p.len()).sum();
assert_eq!(
total_assigned, 4,
"All nodes including isolated should be assigned"
);
}
#[test]
fn test_partition_kway_with_seeds() {
let graph = create_path_graph();
let entity_ids: Vec<i64> = graph.list_entity_ids().expect("Failed to get IDs");
let config = PartitionConfig {
k: 2,
max_size: usize::MAX,
max_imbalance: 0.1,
seeds: Some(vec![entity_ids[0], entity_ids[4]]),
};
let result = partition_kway(&graph, &config).expect("Failed to partition");
assert_eq!(result.partitions.len(), 2, "Should have 2 partitions");
let p0 = result.node_to_partition.get(&entity_ids[0]);
let p4 = result.node_to_partition.get(&entity_ids[4]);
assert!(p0.is_some() && p4.is_some(), "All seeds should be assigned");
assert_ne!(p0, p4, "Seeds should be in different partitions");
}
#[test]
fn test_partition_kway_invalid_k() {
let graph = create_path_graph();
let config = PartitionConfig {
            k: 1,
            ..Default::default()
};
let result = partition_kway(&graph, &config);
assert!(result.is_err(), "Should return error for k < 2");
}
#[test]
fn test_partition_kway_with_progress_matches() {
use crate::progress::NoProgress;
let graph = create_path_graph();
let config = PartitionConfig::default();
let progress = NoProgress;
let result_with = partition_kway_with_progress(&graph, &config, &progress).expect("Failed");
let result_without = partition_kway(&graph, &config).expect("Failed");
assert_eq!(
result_with.partitions.len(),
result_without.partitions.len(),
"Partition count should match"
);
let total_with: usize = result_with.partitions.iter().map(|p| p.len()).sum();
let total_without: usize = result_without.partitions.iter().map(|p| p.len()).sum();
assert_eq!(
total_with, total_without,
"Total assigned nodes should match"
);
}
#[test]
fn test_partition_result_consistency() {
let graph = create_binary_tree(2);
let result = partition_bfs_level(&graph, vec![], 3).expect("Failed to partition");
for (pidx, partition) in result.partitions.iter().enumerate() {
for &node in partition {
assert_eq!(
result.node_to_partition.get(&node),
Some(&pidx),
"Node {} should map to partition {}",
node,
pidx
);
}
}
}
#[test]
fn test_partition_empty_graph() {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
let result = partition_bfs_level(&graph, vec![], 2).expect("Failed to partition");
assert_eq!(result.partitions.len(), 2, "Should have k partitions");
assert!(
result.partitions.iter().all(|p| p.is_empty()),
"All partitions should be empty"
);
assert!(result.cut_edges.is_empty(), "No cut edges for empty graph");
}
#[test]
fn test_partition_k_greater_than_nodes() {
let graph = SqliteGraph::open_in_memory().expect("Failed to create graph");
for i in 0..3 {
let entity = GraphEntity {
id: 0,
kind: "node".to_string(),
name: format!("node_{}", i),
file_path: Some(format!("node_{}.rs", i)),
data: serde_json::json!({}),
};
graph
.insert_entity(&entity)
.expect("Failed to insert entity");
}
let result = partition_bfs_level(&graph, vec![], 10).expect("Failed to partition");
assert_eq!(result.partitions.len(), 10, "Should have 10 partitions");
let non_empty_count = result.partitions.iter().filter(|p| !p.is_empty()).count();
assert_eq!(non_empty_count, 3, "Only 3 partitions should be non-empty");
}
}