Skip to main content

problemreductions/models/misc/
knapsack.rs

1//! Knapsack problem implementation.
2//!
3//! The 0-1 Knapsack problem asks for a subset of items that maximizes
4//! total value while respecting a weight capacity constraint.
5
6use crate::registry::{FieldInfo, ProblemSchemaEntry};
7use crate::traits::{OptimizationProblem, Problem};
8use crate::types::{Direction, SolutionSize};
9use serde::{Deserialize, Serialize};
10
11inventory::submit! {
12    ProblemSchemaEntry {
13        name: "Knapsack",
14        display_name: "Knapsack",
15        aliases: &[],
16        dimensions: &[],
17        module_path: module_path!(),
18        description: "Select items to maximize total value subject to weight capacity constraint",
19        fields: &[
20            FieldInfo { name: "weights", type_name: "Vec<i64>", description: "Nonnegative item weights w_i" },
21            FieldInfo { name: "values", type_name: "Vec<i64>", description: "Nonnegative item values v_i" },
22            FieldInfo { name: "capacity", type_name: "i64", description: "Nonnegative knapsack capacity C" },
23        ],
24    }
25}
26
27/// The 0-1 Knapsack problem.
28///
29/// Given `n` items, each with nonnegative weight `w_i` and nonnegative value `v_i`,
30/// and a nonnegative capacity `C`,
31/// find a subset `S ⊆ {0, ..., n-1}` such that `∑_{i∈S} w_i ≤ C`,
32/// maximizing `∑_{i∈S} v_i`.
33///
34/// # Representation
35///
36/// Each item has a binary variable: `x_i = 1` if item `i` is selected, `0` otherwise.
37///
38/// # Example
39///
40/// ```
41/// use problemreductions::models::misc::Knapsack;
42/// use problemreductions::{Problem, Solver, BruteForce};
43///
44/// let problem = Knapsack::new(vec![2, 3, 4, 5], vec![3, 4, 5, 7], 7);
45/// let solver = BruteForce::new();
46/// let solution = solver.find_best(&problem);
47/// assert!(solution.is_some());
48/// ```
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct Knapsack {
51    #[serde(deserialize_with = "nonnegative_i64_vec::deserialize")]
52    weights: Vec<i64>,
53    #[serde(deserialize_with = "nonnegative_i64_vec::deserialize")]
54    values: Vec<i64>,
55    #[serde(deserialize_with = "nonnegative_i64::deserialize")]
56    capacity: i64,
57}
58
59impl Knapsack {
60    /// Create a new Knapsack instance.
61    ///
62    /// # Panics
63    /// Panics if `weights` and `values` have different lengths, or if any
64    /// weight, value, or the capacity is negative.
65    pub fn new(weights: Vec<i64>, values: Vec<i64>, capacity: i64) -> Self {
66        assert_eq!(
67            weights.len(),
68            values.len(),
69            "weights and values must have the same length"
70        );
71        assert!(
72            weights.iter().all(|&weight| weight >= 0),
73            "Knapsack weights must be nonnegative"
74        );
75        assert!(
76            values.iter().all(|&value| value >= 0),
77            "Knapsack values must be nonnegative"
78        );
79        assert!(capacity >= 0, "Knapsack capacity must be nonnegative");
80        Self {
81            weights,
82            values,
83            capacity,
84        }
85    }
86
87    /// Returns the item weights.
88    pub fn weights(&self) -> &[i64] {
89        &self.weights
90    }
91
92    /// Returns the item values.
93    pub fn values(&self) -> &[i64] {
94        &self.values
95    }
96
97    /// Returns the knapsack capacity.
98    pub fn capacity(&self) -> i64 {
99        self.capacity
100    }
101
102    /// Returns the number of items.
103    pub fn num_items(&self) -> usize {
104        self.weights.len()
105    }
106
107    /// Returns the number of binary slack bits used by the QUBO encoding.
108    ///
109    /// For positive capacity this is `floor(log2(C)) + 1`; for zero capacity we
110    /// keep one slack bit so the encoding shape remains uniform.
111    pub fn num_slack_bits(&self) -> usize {
112        if self.capacity == 0 {
113            1
114        } else {
115            (u64::BITS - (self.capacity as u64).leading_zeros()) as usize
116        }
117    }
118}
119
120impl Problem for Knapsack {
121    const NAME: &'static str = "Knapsack";
122    type Metric = SolutionSize<i64>;
123
124    fn variant() -> Vec<(&'static str, &'static str)> {
125        crate::variant_params![]
126    }
127
128    fn dims(&self) -> Vec<usize> {
129        vec![2; self.num_items()]
130    }
131
132    fn evaluate(&self, config: &[usize]) -> SolutionSize<i64> {
133        if config.len() != self.num_items() {
134            return SolutionSize::Invalid;
135        }
136        if config.iter().any(|&v| v >= 2) {
137            return SolutionSize::Invalid;
138        }
139        let total_weight: i64 = config
140            .iter()
141            .enumerate()
142            .filter(|(_, &x)| x == 1)
143            .map(|(i, _)| self.weights[i])
144            .sum();
145        if total_weight > self.capacity {
146            return SolutionSize::Invalid;
147        }
148        let total_value: i64 = config
149            .iter()
150            .enumerate()
151            .filter(|(_, &x)| x == 1)
152            .map(|(i, _)| self.values[i])
153            .sum();
154        SolutionSize::Valid(total_value)
155    }
156}
157
158impl OptimizationProblem for Knapsack {
159    type Value = i64;
160
161    fn direction(&self) -> Direction {
162        Direction::Maximize
163    }
164}
165
// Declares Knapsack's variant through the project's `declare_variants!` macro. The string
// literal is the complexity expression attached to this variant. `2^(num_items / 2)` is the
// meet-in-the-middle (Horowitz–Sahni) bound for exact 0-1 Knapsack.
// NOTE(review): the meaning of the `default` and `opt` keywords is defined by the macro, not
// here — confirm against its definition.
crate::declare_variants! {
    default opt Knapsack => "2^(num_items / 2)",
}
169
170mod nonnegative_i64 {
171    use serde::de::Error;
172    use serde::{Deserialize, Deserializer};
173
174    pub fn deserialize<'de, D>(deserializer: D) -> Result<i64, D::Error>
175    where
176        D: Deserializer<'de>,
177    {
178        let value = i64::deserialize(deserializer)?;
179        if value < 0 {
180            return Err(D::Error::custom(format!(
181                "expected nonnegative integer, got {value}"
182            )));
183        }
184        Ok(value)
185    }
186}
187
188mod nonnegative_i64_vec {
189    use serde::de::Error;
190    use serde::{Deserialize, Deserializer};
191
192    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<i64>, D::Error>
193    where
194        D: Deserializer<'de>,
195    {
196        let values = Vec::<i64>::deserialize(deserializer)?;
197        if let Some(value) = values.iter().copied().find(|value| *value < 0) {
198            return Err(D::Error::custom(format!(
199                "expected nonnegative integers, got {value}"
200            )));
201        }
202        Ok(values)
203    }
204}
205
// Unit tests for this model live in a separate file. `#[path]` names that file, and
// `#[cfg(test)]` compiles the module only in test builds.
#[cfg(test)]
#[path = "../../unit_tests/models/misc/knapsack.rs"]
mod tests;