use crate::error::OptimizeError;
/// Result type returned by the fallible solvers in this module
/// (`knapsack_dp`, `knapsack_branch_bound`, `multi_knapsack_greedy`).
pub type KnapsackResult<T> = std::result::Result<T, OptimizeError>;
/// One candidate item for the single-dimension knapsack solvers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct KnapsackItem {
    /// Capacity consumed when the item is packed.
    pub weight: u64,
    /// Value gained when the item is packed.
    pub value: u64,
}
/// Solves the 0/1 knapsack problem exactly with a bottom-up dynamic-programming table.
///
/// `dp[i][c]` holds the best value achievable with the first `i` items and capacity `c`.
/// The chosen items are recovered by walking the table backwards.
///
/// The table width is clamped to `min(capacity, total item weight)`. Capacity beyond the
/// combined weight of all items can never be used, so the value and selection are unchanged.
/// However, very large capacities paired with light items no longer overflow or exhaust the
/// size limit.
///
/// Returns `(best_value, selected_indices)` with indices ascending into `items`. Value sums
/// saturate at `u64::MAX`.
///
/// # Errors
/// Returns `OptimizeError::InvalidInput` if the table would exceed 500M `u64` cells (about 4 GB).
pub fn knapsack_dp(items: &[KnapsackItem], capacity: u64) -> KnapsackResult<(u64, Vec<usize>)> {
    let n = items.len();
    let total_weight = items
        .iter()
        .fold(0u64, |acc, item| acc.saturating_add(item.weight));
    // The clamped capacity may not fit in `usize` (32-bit targets). In that case map it to
    // `usize::MAX` so the size check below rejects it instead of silently truncating.
    let w = usize::try_from(capacity.min(total_weight)).unwrap_or(usize::MAX);
    let cols = w.saturating_add(1);
    let table_size = (n + 1).saturating_mul(cols);
    if table_size > 500_000_000 {
        return Err(OptimizeError::InvalidInput(format!(
            "DP table size {table_size} exceeds 500M; use branch-and-bound for large capacities"
        )));
    }
    let mut dp = vec![0u64; table_size];
    for (i, item) in items.iter().enumerate() {
        // A weight that does not fit in `usize` is larger than `w`, so the item never fits.
        let iw = usize::try_from(item.weight).unwrap_or(usize::MAX);
        // Row `i` (already filled) is read while row `i + 1` is written.
        let (done, rest) = dp.split_at_mut((i + 1) * cols);
        let prev = &done[i * cols..];
        let cur = &mut rest[..cols];
        for (c, slot) in cur.iter_mut().enumerate() {
            let with_item = if iw <= c {
                prev[c - iw].saturating_add(item.value)
            } else {
                0
            };
            *slot = prev[c].max(with_item);
        }
    }
    // Backtrack: item `i - 1` was taken iff adding it changed the optimum at this capacity.
    let mut selected = Vec::new();
    let mut remaining = w;
    for i in (1..=n).rev() {
        if dp[i * cols + remaining] != dp[(i - 1) * cols + remaining] {
            selected.push(i - 1);
            let iw = usize::try_from(items[i - 1].weight).unwrap_or(usize::MAX);
            remaining = remaining.saturating_sub(iw);
        }
    }
    selected.reverse();
    let total = dp[n * cols + w];
    Ok((total, selected))
}
/// Optimal value of the fractional (LP-relaxed) knapsack problem.
///
/// Items are packed in descending value/weight order. Whole items are taken while they fit,
/// and then the fitting fraction of the first item that does not fit is taken. Weightless
/// items have infinite density, so they are always counted in full, even when `capacity` is 0
/// (this matches `knapsack_dp`, which also packs weightless items at zero capacity).
///
/// Returns `0.0` for an empty item list.
pub fn fractional_knapsack(items: &[KnapsackItem], capacity: u64) -> f64 {
    // Pair each index with its density. Weightless items get +inf so they sort first.
    let mut by_density: Vec<(usize, f64)> = items
        .iter()
        .enumerate()
        .map(|(i, it)| {
            let ratio = if it.weight == 0 {
                f64::INFINITY
            } else {
                it.value as f64 / it.weight as f64
            };
            (i, ratio)
        })
        .collect();
    by_density.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let mut remaining = capacity as f64;
    let mut total_value = 0.0;
    for &(idx, _) in &by_density {
        let item = &items[idx];
        let weight = item.weight as f64;
        if weight <= remaining {
            // The whole item fits. This always holds for weightless items.
            total_value += item.value as f64;
            remaining -= weight;
        } else {
            // Take exactly the fraction that fills the knapsack, then stop. Stopping here
            // avoids carrying floating-point residue into later items.
            total_value += remaining / weight * item.value as f64;
            break;
        }
    }
    total_value
}
/// Greedy 0/1 knapsack heuristic that packs whole items in descending value/weight order
/// while they fit.
///
/// Weightless items rank first (infinite density) and are always packed, even when
/// `capacity` is 0. The result is always feasible but not necessarily optimal.
///
/// Returns `(total_value, selected_indices)` with indices ascending into `items`. The value
/// sum saturates at `u64::MAX`, matching `knapsack_dp`.
pub fn knapsack_greedy(items: &[KnapsackItem], capacity: u64) -> (u64, Vec<usize>) {
    // Zero capacity is not a shortcut, because weightless items still fit. An empty item
    // list falls through the loop and yields (0, []).
    let mut indexed: Vec<(usize, f64)> = items
        .iter()
        .enumerate()
        .map(|(i, it)| {
            let ratio = if it.weight == 0 {
                f64::INFINITY
            } else {
                it.value as f64 / it.weight as f64
            };
            (i, ratio)
        })
        .collect();
    // Stable sort by descending density; ties keep input order.
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let mut remaining = capacity;
    let mut total = 0u64;
    let mut selected = Vec::new();
    for &(idx, _) in &indexed {
        let item = &items[idx];
        if item.weight <= remaining {
            selected.push(idx);
            remaining -= item.weight;
            // Saturate instead of panicking on overflow.
            total = total.saturating_add(item.value);
        }
    }
    selected.sort_unstable();
    (total, selected)
}
/// A partial solution on the branch-and-bound search stack.
#[derive(Clone, Debug)]
struct BbNode {
    /// Number of items (in density order) whose include/exclude decision is already made.
    level: usize,
    /// Total value of the items taken so far.
    value: u64,
    /// Total weight of the items taken so far.
    weight: u64,
    /// LP-relaxation upper bound on any completion of this node.
    bound: f64,
    /// `taken[i]` is true if original item `i` is packed in this partial solution.
    taken: Vec<bool>,
}
/// Dantzig LP-relaxation upper bound for a branch-and-bound node.
///
/// Starts from a partial solution worth `value` and weighing `weight`. It then walks the
/// still-undecided items `sorted_indices[level..]`, which `knapsack_branch_bound` sorts by
/// descending value density. Whole items are added while they fit, and then the fitting
/// fraction of the first one that does not.
///
/// Returns `0.0` for an infeasible partial solution (`weight > capacity`).
fn lp_bound(
    items: &[KnapsackItem],
    sorted_indices: &[usize],
    level: usize,
    value: u64,
    weight: u64,
    capacity: u64,
) -> f64 {
    let mut room = match capacity.checked_sub(weight) {
        Some(free) => free as f64,
        None => return 0.0,
    };
    let mut bound = value as f64;
    for item in sorted_indices.iter().skip(level).map(|&idx| &items[idx]) {
        let item_weight = item.weight as f64;
        if item_weight > room {
            // Only the fraction that fits counts. The weight guard prevents a division by zero.
            if item.weight > 0 {
                bound += room * (item.value as f64 / item_weight);
            }
            break;
        }
        bound += item.value as f64;
        room -= item_weight;
    }
    bound
}
/// Solves the 0/1 knapsack problem exactly with depth-first branch-and-bound.
///
/// Items are branched on in descending value/weight order, with weightless items first. The
/// search is seeded with the `knapsack_greedy` solution as the incumbent. A node is pruned
/// whenever its `lp_bound` LP-relaxation bound does not strictly exceed the incumbent value,
/// so among equally valuable selections the first one found is kept.
///
/// Zero capacity is not short-circuited, because weightless items can still be packed.
/// Weight sums use `checked_add`, so very heavy items cannot overflow; value sums saturate.
///
/// NOTE(review): bounds are compared as `f64`. For values above 2^53 rounding may make
/// pruning approximate.
///
/// Returns `(best_value, selected_indices)` with indices ascending into `items`.
///
/// # Errors
/// The current implementation always returns `Ok`.
pub fn knapsack_branch_bound(
    items: &[KnapsackItem],
    capacity: u64,
) -> KnapsackResult<(u64, Vec<usize>)> {
    let n = items.len();
    // Only an empty item list is trivial. At zero capacity, weightless items still fit.
    if n == 0 {
        return Ok((0, vec![]));
    }
    let density = |item: &KnapsackItem| -> f64 {
        if item.weight == 0 {
            f64::INFINITY
        } else {
            item.value as f64 / item.weight as f64
        }
    };
    let mut sorted_indices: Vec<usize> = (0..n).collect();
    // Stable sort, densest first; ties keep input order.
    sorted_indices.sort_by(|&a, &b| {
        density(&items[b])
            .partial_cmp(&density(&items[a]))
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Seed the incumbent with the greedy solution so pruning starts immediately.
    let (mut best_value, greedy_selection) = knapsack_greedy(items, capacity);
    let mut best_taken = vec![false; n];
    for idx in greedy_selection {
        best_taken[idx] = true;
    }

    let mut stack = vec![BbNode {
        level: 0,
        value: 0,
        weight: 0,
        bound: lp_bound(items, &sorted_indices, 0, 0, 0, capacity),
        taken: vec![false; n],
    }];
    while let Some(node) = stack.pop() {
        if node.level == n {
            if node.value > best_value {
                best_value = node.value;
                best_taken = node.taken;
            }
            continue;
        }
        // The incumbent may have improved since this node was pushed.
        if node.bound <= best_value as f64 {
            continue;
        }
        let next_level = node.level + 1;
        let item_idx = sorted_indices[node.level];
        let item = &items[item_idx];

        // Include branch. `checked_add` rejects weights whose sum would overflow u64.
        if let Some(new_weight) = node
            .weight
            .checked_add(item.weight)
            .filter(|&w| w <= capacity)
        {
            let new_value = node.value.saturating_add(item.value);
            let bound = lp_bound(
                items,
                &sorted_indices,
                next_level,
                new_value,
                new_weight,
                capacity,
            );
            if bound > best_value as f64 {
                let mut taken = node.taken.clone();
                taken[item_idx] = true;
                stack.push(BbNode {
                    level: next_level,
                    value: new_value,
                    weight: new_weight,
                    bound,
                    taken,
                });
            }
        }

        // Exclude branch. It is pushed last, so it is popped (explored) first.
        let bound = lp_bound(
            items,
            &sorted_indices,
            next_level,
            node.value,
            node.weight,
            capacity,
        );
        if bound > best_value as f64 {
            stack.push(BbNode {
                level: next_level,
                value: node.value,
                weight: node.weight,
                bound,
                // `node` is not used again, so its selection is moved instead of cloned.
                taken: node.taken,
            });
        }
    }
    let selected: Vec<usize> = (0..n).filter(|&i| best_taken[i]).collect();
    Ok((best_value, selected))
}
/// One candidate item for the multi-dimensional knapsack heuristic.
#[derive(Clone, Debug)]
pub struct MultiKnapsackItem {
    /// One weight per capacity dimension. `multi_knapsack_greedy` requires the length to
    /// match its `capacities` slice.
    pub weights: Vec<u64>,
    /// Value gained when the item is packed.
    pub value: u64,
}
/// Greedy heuristic for the multi-dimensional 0/1 knapsack problem, followed by a
/// 1-for-1 swap local search.
///
/// The search runs in two phases:
/// - Phase 1 packs items in descending order of `value / ||weights||_2`, provided every
///   dimension still has room.
/// - Phase 2 repeatedly applies the first swap (one packed item out, one unpacked item in)
///   that strictly increases value while staying within every capacity. It stops when no
///   such swap exists.
///
/// The result is feasible but not guaranteed optimal.
///
/// Returns `(total_value, selected_indices)` with indices ascending into `items`. The value
/// sum saturates at `u64::MAX`. If `items` or `capacities` is empty, returns `(0, [])`
/// without validating dimensions.
///
/// # Errors
/// Returns `OptimizeError::InvalidInput` if some item's `weights.len()` differs from
/// `capacities.len()`.
pub fn multi_knapsack_greedy(
    items: &[MultiKnapsackItem],
    capacities: &[u64],
) -> KnapsackResult<(u64, Vec<usize>)> {
    let n = items.len();
    let d = capacities.len();
    if n == 0 || d == 0 {
        return Ok((0, vec![]));
    }
    for (i, item) in items.iter().enumerate() {
        if item.weights.len() != d {
            return Err(OptimizeError::InvalidInput(format!(
                "Item {i} has {} weight dimensions but capacities has {d}",
                item.weights.len()
            )));
        }
    }

    // Phase 1: greedy by value per unit of Euclidean weight norm.
    let mut order: Vec<(usize, f64)> = items
        .iter()
        .enumerate()
        .map(|(i, it)| {
            let norm_sq: f64 = it.weights.iter().map(|&w| (w as f64).powi(2)).sum();
            // Floor the norm so all-zero weight vectors rank first instead of dividing by 0.
            let norm = norm_sq.sqrt().max(1e-12);
            (i, it.value as f64 / norm)
        })
        .collect();
    // Stable sort by descending score; ties keep input order.
    order.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    let mut remaining = capacities.to_vec();
    let mut selected = vec![false; n];
    for &(idx, _) in &order {
        let weights = &items[idx].weights;
        if weights.iter().zip(&remaining).all(|(&w, &room)| w <= room) {
            selected[idx] = true;
            for (room, &w) in remaining.iter_mut().zip(weights) {
                *room -= w;
            }
        }
    }

    // Phase 2: every applied swap strictly increases total value, so the loop terminates.
    while let Some((out_idx, in_idx)) = multi_knapsack_find_swap(items, &selected, &remaining)
    {
        let out_weights = &items[out_idx].weights;
        let in_weights = &items[in_idx].weights;
        for ((room, &freed), &needed) in remaining.iter_mut().zip(out_weights).zip(in_weights) {
            // `room + freed` is at most the capacity, because `out_idx` was packed, so it
            // cannot overflow. The feasibility check guarantees it is >= `needed`.
            *room = *room + freed - needed;
        }
        selected[out_idx] = false;
        selected[in_idx] = true;
    }

    let total = items
        .iter()
        .zip(&selected)
        .filter(|&(_, &sel)| sel)
        .fold(0u64, |acc, (item, _)| acc.saturating_add(item.value));
    let result: Vec<usize> = (0..n).filter(|&i| selected[i]).collect();
    Ok((total, result))
}

/// Finds the first improving, capacity-feasible swap for `multi_knapsack_greedy`.
///
/// Scans packed items in index order as the outgoing item and unpacked items in index
/// order as the incoming one. Returns `(out, in)` where `in` is strictly more valuable and
/// fits in the room left after removing `out`.
fn multi_knapsack_find_swap(
    items: &[MultiKnapsackItem],
    selected: &[bool],
    remaining: &[u64],
) -> Option<(usize, usize)> {
    for (out_idx, out_item) in items.iter().enumerate() {
        if !selected[out_idx] {
            continue;
        }
        for (in_idx, in_item) in items.iter().enumerate() {
            // Compare as u64. A difference of `value as i64` casts wraps above i64::MAX.
            if selected[in_idx] || in_item.value <= out_item.value {
                continue;
            }
            let fits = in_item
                .weights
                .iter()
                .zip(&out_item.weights)
                .zip(remaining)
                .all(|((&needed, &freed), &room)| freed.saturating_add(room) >= needed);
            if fits {
                return Some((out_idx, in_idx));
            }
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four items whose 0/1 optimum at capacity 5 is items 2 and 3
    /// (weight 2 + 3 = 5, value 5 + 6 = 11).
    fn classic_items() -> Vec<KnapsackItem> {
        vec![
            KnapsackItem { weight: 2, value: 3 },
            KnapsackItem { weight: 3, value: 4 },
            KnapsackItem { weight: 2, value: 5 },
            KnapsackItem { weight: 3, value: 6 },
        ]
    }

    #[test]
    fn test_dp_classic() {
        let items = classic_items();
        let (val, sel) = knapsack_dp(&items, 5).expect("unexpected None or Err");
        assert_eq!(val, 11, "expected value 11, got {val}");
        let total_weight: u64 = sel.iter().map(|&i| items[i].weight).sum();
        assert!(total_weight <= 5);
        let total_val: u64 = sel.iter().map(|&i| items[i].value).sum();
        assert_eq!(total_val, val);
    }

    #[test]
    fn test_dp_empty() {
        let (val, sel) = knapsack_dp(&[], 10).expect("unexpected None or Err");
        assert_eq!(val, 0);
        assert!(sel.is_empty());
    }

    #[test]
    fn test_dp_zero_capacity() {
        let items = classic_items();
        let (val, sel) = knapsack_dp(&items, 0).expect("unexpected None or Err");
        assert_eq!(val, 0);
        assert!(sel.is_empty());
    }

    #[test]
    fn test_fractional_knapsack() {
        let items = classic_items();
        let val = fractional_knapsack(&items, 5);
        // The two densest items fill capacity 5 exactly, so the LP optimum equals the 0/1 optimum.
        assert!((val - 11.0).abs() < 1e-9, "expected 11.0, got {val}");
    }

    #[test]
    fn test_greedy_knapsack() {
        let items = classic_items();
        let (val, sel) = knapsack_greedy(&items, 5);
        assert!(val > 0);
        let total_weight: u64 = sel.iter().map(|&i| items[i].weight).sum();
        assert!(total_weight <= 5);
    }

    #[test]
    fn test_branch_bound_classic() {
        let items = classic_items();
        let (val, sel) = knapsack_branch_bound(&items, 5).expect("unexpected None or Err");
        assert_eq!(val, 11);
        let total_weight: u64 = sel.iter().map(|&i| items[i].weight).sum();
        assert!(total_weight <= 5);
        let total_val: u64 = sel.iter().map(|&i| items[i].value).sum();
        assert_eq!(total_val, val);
    }

    #[test]
    fn test_bb_equals_dp() {
        let items = vec![
            KnapsackItem { weight: 1, value: 6 },
            KnapsackItem { weight: 2, value: 10 },
            KnapsackItem { weight: 3, value: 12 },
        ];
        let cap = 5;
        let (dp_val, _) = knapsack_dp(&items, cap).expect("unexpected None or Err");
        let (bb_val, _) = knapsack_branch_bound(&items, cap).expect("unexpected None or Err");
        assert_eq!(dp_val, bb_val, "DP and B&B should agree");
    }

    #[test]
    fn test_bb_equals_dp_textbook() {
        // Textbook instance: greedy by density takes items 0 and 1 (value 160), but the
        // unique optimum is items 1 and 2 (weight 50, value 220).
        let items = vec![
            KnapsackItem { weight: 10, value: 60 },
            KnapsackItem { weight: 20, value: 100 },
            KnapsackItem { weight: 30, value: 120 },
        ];
        let (dp_val, dp_sel) = knapsack_dp(&items, 50).expect("unexpected None or Err");
        let (bb_val, bb_sel) = knapsack_branch_bound(&items, 50).expect("unexpected None or Err");
        assert_eq!(dp_val, 220);
        assert_eq!(bb_val, 220);
        assert_eq!(dp_sel, vec![1, 2]);
        assert_eq!(bb_sel, vec![1, 2]);
    }

    #[test]
    fn test_multi_knapsack() {
        let items = vec![
            MultiKnapsackItem { weights: vec![2, 1], value: 5 },
            MultiKnapsackItem { weights: vec![1, 2], value: 5 },
            MultiKnapsackItem { weights: vec![3, 3], value: 8 },
        ];
        let caps = vec![4, 4];
        let (val, sel) = multi_knapsack_greedy(&items, &caps).expect("unexpected None or Err");
        assert!(val > 0);
        for dim in 0..2 {
            let used: u64 = sel.iter().map(|&i| items[i].weights[dim]).sum();
            assert!(used <= caps[dim]);
        }
    }

    #[test]
    fn test_fractional_zero_capacity() {
        let items = classic_items();
        assert_eq!(fractional_knapsack(&items, 0), 0.0);
    }

    #[test]
    fn test_all_items_fit() {
        let items = classic_items();
        let total_weight: u64 = items.iter().map(|i| i.weight).sum();
        let total_value: u64 = items.iter().map(|i| i.value).sum();
        let (val, sel) = knapsack_dp(&items, total_weight + 100).expect("unexpected None or Err");
        assert_eq!(val, total_value);
        assert_eq!(sel.len(), items.len());
    }
}