use crate::{Cost, Error, Result, SolverStatus};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// A weighted set-cover instance over the universe of element ids `0..num_elements`.
///
/// Each entry of `sets` is a `(cost, elements)` pair; a set is referred to by its
/// position in this vector. [`SetCoverProblem::new`] rejects element ids outside the
/// universe, but building the struct directly through its public fields or via the
/// derived `Deserialize` skips that check. Duplicate ids within one set are not
/// rejected by `new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SetCoverProblem {
/// Size of the universe; valid element ids are `0..num_elements`.
pub num_elements: usize,
/// Candidate sets as `(cost, elements)` pairs, indexed by position.
pub sets: Vec<(Cost, Vec<usize>)>,
}
impl SetCoverProblem {
pub fn new(num_elements: usize, sets: Vec<(Cost, Vec<usize>)>) -> Result<Self> {
for (_, elements) in &sets {
for &e in elements {
if e >= num_elements {
return Err(Error::invalid_input(format!(
"element {} out of range [0, {})",
e, num_elements
)));
}
}
}
Ok(Self { num_elements, sets })
}
pub fn unit_cost(num_elements: usize, sets: Vec<Vec<usize>>) -> Result<Self> {
let sets_with_cost = sets.into_iter()
.map(|s| (1, s))
.collect();
Self::new(num_elements, sets_with_cost)
}
pub fn num_sets(&self) -> usize {
self.sets.len()
}
}
/// A cover returned by the set-cover solver.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SetCoverSolution {
/// Indices into `SetCoverProblem::sets` of the chosen sets; `greedy` records them
/// in the order they were picked.
pub selected: Vec<usize>,
/// Sum of the costs of the selected sets.
pub total_cost: Cost,
/// Solver status; `greedy` always reports `SolverStatus::Feasible`.
pub status: SolverStatus,
}
pub fn greedy(problem: &SetCoverProblem) -> Result<SetCoverSolution> {
let mut uncovered: HashSet<usize> = (0..problem.num_elements).collect();
let mut selected = Vec::new();
let mut total_cost: Cost = 0;
while !uncovered.is_empty() {
let mut best_set = None;
let mut best_ratio = f64::INFINITY;
for (idx, (cost, elements)) in problem.sets.iter().enumerate() {
if selected.contains(&idx) {
continue;
}
let new_covered: usize = elements.iter()
.filter(|e| uncovered.contains(e))
.count();
if new_covered == 0 {
continue;
}
let ratio = *cost as f64 / new_covered as f64;
if ratio < best_ratio {
best_ratio = ratio;
best_set = Some(idx);
}
}
match best_set {
Some(idx) => {
let (cost, elements) = &problem.sets[idx];
selected.push(idx);
total_cost += cost;
for &e in elements {
uncovered.remove(&e);
}
}
None => {
return Err(Error::infeasible(
"not all elements can be covered"
));
}
}
}
Ok(SetCoverSolution {
selected,
total_cost,
status: SolverStatus::Feasible, })
}
/// Solves `problem` with the default algorithm, which is currently [`greedy`].
///
/// # Errors
///
/// Propagates any error from [`greedy`], including its infeasibility error when
/// some element is contained in no set.
pub fn solve(problem: &SetCoverProblem) -> Result<SetCoverSolution> {
greedy(problem)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects every element id covered by the sets chosen in `solution`.
    fn covered_by(problem: &SetCoverProblem, solution: &SetCoverSolution) -> HashSet<usize> {
        solution
            .selected
            .iter()
            .flat_map(|&idx| problem.sets[idx].1.iter().copied())
            .collect()
    }

    #[test]
    fn test_simple_cover() {
        let sets = vec![
            (1, vec![0, 1, 2]),
            (1, vec![2, 3]),
            (1, vec![3, 4]),
            (1, vec![4, 0]),
        ];
        let problem = SetCoverProblem::new(5, sets).unwrap();
        let solution = solve(&problem).unwrap();
        assert_eq!(covered_by(&problem, &solution).len(), 5);
    }

    #[test]
    fn test_unit_cost() {
        let sets = vec![vec![0, 1], vec![1, 2], vec![0, 2]];
        let problem = SetCoverProblem::unit_cost(3, sets).unwrap();
        let solution = solve(&problem).unwrap();
        assert!(solution.total_cost <= 2);
    }

    #[test]
    fn test_infeasible() {
        // Element 2 appears in no set.
        let sets = vec![(1, vec![0]), (1, vec![1])];
        let problem = SetCoverProblem::new(3, sets).unwrap();
        assert!(solve(&problem).is_err());
    }
}