use std::collections::{HashMap, HashSet};
/// An undirected graph whose vertices carry a weight: the cost of putting that
/// vertex into a vertex cover.
///
/// Vertices are identified by their index into `weights`. Every endpoint that
/// appears in `edges` must be a valid index into `weights` (checked by [`solve`]).
#[derive(Debug, Clone)]
pub struct WeightedGraph {
    /// Undirected edges as `(u, v)` vertex-index pairs; `(u, u)` self-loops are allowed.
    edges: Vec<(usize, usize)>,
    /// `weights[v]` is the cost of vertex `v`.
    weights: Vec<f64>,
}

impl WeightedGraph {
    /// Builds a graph from an edge list and a per-vertex weight table.
    ///
    /// No validation is performed here; [`solve`] panics if an edge endpoint has
    /// no entry in `weights`.
    pub fn new(edges: Vec<(usize, usize)>, weights: Vec<f64>) -> Self {
        Self { edges, weights }
    }
}

/// Computes a vertex cover of `graph` and returns it with its total weight.
///
/// Uses the Bar-Yehuda–Even local-ratio (primal-dual) scheme. Each edge not yet
/// covered "pays" the smaller residual weight of its two endpoints. That endpoint
/// becomes tight (residual zero) and joins the cover. Every vertex in the result
/// is tight or was free to begin with, so for non-negative weights the cover costs
/// at most twice the optimum.
///
/// Afterwards, redundant vertices (all of whose neighbours are already in the
/// cover) are dropped, heaviest first. Dropping vertices preserves the guarantee,
/// and the returned cover is minimal.
///
/// Vertices with weight `<= 0.0` are treated as free and are taken whenever
/// they are needed to cover an edge. The result is deterministic for a given
/// input.
///
/// # Panics
///
/// Panics if any edge endpoint is not a valid index into the graph's weights.
pub fn solve(graph: &WeightedGraph) -> (HashSet<usize>, f64) {
    let vertex_count = graph.weights.len();
    for &(u, v) in &graph.edges {
        assert!(
            u < vertex_count && v < vertex_count,
            "edge ({u}, {v}) references a vertex with no weight; the graph has {vertex_count} weights"
        );
    }

    let tight = tight_vertex_cover(&graph.edges, &graph.weights);
    let cover = prune_redundant(tight, &graph.edges, &graph.weights);
    let total_weight: f64 = cover.iter().map(|&v| graph.weights[v]).sum();
    (cover, total_weight)
}

/// Single local-ratio pass over `edges`: returns a (not necessarily minimal)
/// cover made only of tight or initially non-positive-weight vertices.
fn tight_vertex_cover(edges: &[(usize, usize)], weights: &[f64]) -> HashSet<usize> {
    let mut residual = weights.to_vec();
    let mut tight = HashSet::new();
    for &(u, v) in edges {
        // Already covered by an earlier tight vertex: nothing left to pay.
        if tight.contains(&u) || tight.contains(&v) {
            continue;
        }
        let (ru, rv) = (residual[u], residual[v]);
        if ru <= 0.0 {
            // Free (or already fully paid) endpoint: take it at no extra cost.
            tight.insert(u);
        } else if rv <= 0.0 {
            tight.insert(v);
        } else if ru <= rv {
            // Pay epsilon = ru on both endpoints: u becomes tight, v keeps the rest.
            // Membership in `tight` is authoritative, so u's residual is never read
            // again; this also handles a self-loop (u == v) without going negative.
            residual[v] = rv - ru;
            tight.insert(u);
        } else {
            residual[u] = ru - rv;
            tight.insert(v);
        }
    }
    tight
}

/// Removes vertices whose every neighbour is already in `cover`, heaviest first,
/// so the result is a minimal cover (no single vertex can be dropped).
fn prune_redundant(
    mut cover: HashSet<usize>,
    edges: &[(usize, usize)],
    weights: &[f64],
) -> HashSet<usize> {
    let mut neighbors: HashMap<usize, Vec<usize>> = HashMap::new();
    for &(u, v) in edges {
        neighbors.entry(u).or_default().push(v);
        neighbors.entry(v).or_default().push(u);
    }

    let mut order: Vec<usize> = cover.iter().copied().collect();
    // Heaviest first saves the most weight; the index tie-break makes the result
    // independent of HashSet iteration order.
    order.sort_unstable_by(|&a, &b| weights[b].total_cmp(&weights[a]).then(a.cmp(&b)));

    for x in order {
        // Every vertex in `cover` is an edge endpoint, so it has an adjacency entry.
        // A self-loop (x, x) can only be covered by x itself, hence `n != x`.
        let redundant = neighbors[&x].iter().all(|&n| n != x && cover.contains(&n));
        if redundant {
            cover.remove(&x);
        }
    }
    cover
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the current test if some edge of `graph` has neither endpoint in `cover`.
    fn check_cover(graph: &WeightedGraph, cover: &HashSet<usize>) {
        for (u, v) in &graph.edges {
            assert!(cover.contains(u) || cover.contains(v));
        }
    }

    #[test]
    fn test_simple_graph() {
        // Path 0 - 1 - 2 with a heavier middle vertex.
        let graph = WeightedGraph::new(vec![(0, 1), (1, 2)], vec![1.0, 2.0, 1.0]);
        let (cover, weight) = solve(&graph);

        assert!(cover.len() <= 2);
        check_cover(&graph, &cover);

        // The reported weight must equal the sum of the chosen vertices' weights.
        let recomputed: f64 = cover.iter().map(|&v| graph.weights[v]).sum();
        assert_eq!(weight, recomputed);
    }

    #[test]
    fn test_star_graph() {
        // Centre 0 is the cheapest vertex and alone covers every edge (optimum 1.0).
        let graph = WeightedGraph::new(vec![(0, 1), (0, 2), (0, 3)], vec![1.0, 2.0, 2.0, 2.0]);
        let (cover, weight) = solve(&graph);

        // The 2-approximation guarantee bounds the result by twice the optimum.
        assert!(weight <= 2.0 * 1.0);
        check_cover(&graph, &cover);
    }

    #[test]
    fn test_empty_graph() {
        // Two isolated vertices: there is nothing to cover.
        let (cover, weight) = solve(&WeightedGraph::new(Vec::new(), vec![1.0, 1.0]));

        assert!(cover.is_empty());
        assert_eq!(weight, 0.0);
    }
}