use std::collections::{HashMap, HashSet};
/// A weighted set cover problem.
///
/// The goal is to choose some of `sets` whose union contains every element of
/// `universe`, while keeping the summed weight of the chosen sets small.
#[derive(Debug, Clone)]
pub struct SetCoverInstance {
    /// Candidate sets; a set is identified by its position in this vector.
    sets: Vec<HashSet<usize>>,
    /// `weights[i]` is the cost of choosing `sets[i]`.
    weights: Vec<f64>,
    /// The elements that a cover must contain.
    universe: HashSet<usize>,
}

impl SetCoverInstance {
    /// Creates an instance from parallel vectors of sets and weights plus the
    /// universe that has to be covered.
    ///
    /// # Panics
    ///
    /// Panics if `sets` and `weights` do not have the same length.
    pub fn new(sets: Vec<HashSet<usize>>, weights: Vec<f64>, universe: HashSet<usize>) -> Self {
        let (set_count, weight_count) = (sets.len(), weights.len());
        assert_eq!(set_count, weight_count, "Each set must have a weight");
        Self {
            universe,
            weights,
            sets,
        }
    }
}
pub fn solve(instance: &SetCoverInstance) -> (Vec<usize>, f64) {
let mut selected_sets = Vec::new();
let mut uncovered: HashSet<_> = instance.universe.clone();
let mut duals: HashMap<usize, f64> = HashMap::new();
for &e in &instance.universe {
duals.insert(e, 0.0);
}
let mut slacks: Vec<f64> = instance.weights.clone();
let mut element_to_sets: HashMap<usize, Vec<usize>> = HashMap::new();
for (set_idx, set) in instance.sets.iter().enumerate() {
for &e in set {
element_to_sets.entry(e).or_default().push(set_idx);
}
}
while !uncovered.is_empty() {
let mut min_slack = f64::INFINITY;
let mut min_slack_set = 0;
for &e in &uncovered {
for &set_idx in element_to_sets.get(&e).unwrap_or(&Vec::new()) {
if slacks[set_idx] < min_slack {
min_slack = slacks[set_idx];
min_slack_set = set_idx;
}
}
}
for &e in &uncovered {
let increase = min_slack;
*duals.get_mut(&e).unwrap() += increase;
for &set_idx in element_to_sets.get(&e).unwrap_or(&Vec::new()) {
slacks[set_idx] -= increase;
}
}
selected_sets.push(min_slack_set);
for &e in &instance.sets[min_slack_set] {
uncovered.remove(&e);
}
}
let total_weight: f64 = selected_sets.iter().map(|&idx| instance.weights[idx]).sum();
(selected_sets, total_weight)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HashSet` from a slice of element ids.
    fn set_of(elements: &[usize]) -> HashSet<usize> {
        elements.iter().cloned().collect()
    }

    /// Asserts that the chosen sets together contain every element of the universe.
    fn assert_covers(instance: &SetCoverInstance, solution: &[usize]) {
        let covered: HashSet<usize> = solution
            .iter()
            .flat_map(|&idx| instance.sets[idx].iter().cloned())
            .collect();
        assert!(
            instance.universe.is_subset(&covered),
            "solution {:?} leaves elements uncovered",
            solution
        );
    }

    #[test]
    fn test_simple_set_cover() {
        let sets = vec![set_of(&[1, 2]), set_of(&[2, 3]), set_of(&[3, 4])];
        let universe: HashSet<usize> = (1..=4).collect();
        let instance = SetCoverInstance::new(sets, vec![1.0, 1.0, 1.0], universe);
        let (solution, weight) = solve(&instance);
        assert_covers(&instance, &solution);
        // The primal-dual guarantee is weight <= f * OPT. Here f = 2 (no element
        // lies in more than two sets) and OPT = 2 (sets 0 and 2 cover everything).
        let max_frequency = 2.0;
        let optimum = 2.0;
        assert!(weight <= max_frequency * optimum);
    }

    #[test]
    fn test_weighted_set_cover() {
        let sets = vec![set_of(&[1, 2, 3]), set_of(&[1]), set_of(&[2, 3])];
        let universe: HashSet<usize> = (1..=3).collect();
        let instance = SetCoverInstance::new(sets, vec![10.0, 1.0, 3.0], universe);
        let (solution, weight) = solve(&instance);
        assert_covers(&instance, &solution);
        // The single set of weight 10 should be beaten by {1} + {2, 3} at weight 4,
        // which is also the optimum.
        assert!(!solution.contains(&0));
        assert_eq!(weight, 4.0);
    }

    #[test]
    fn test_empty_instance() {
        let instance = SetCoverInstance::new(Vec::new(), Vec::new(), HashSet::new());
        let (solution, weight) = solve(&instance);
        assert!(solution.is_empty());
        assert_eq!(weight, 0.0);
    }
}