use thiserror::Error;
/// Returns the binomial coefficient `C(n, k)`: the number of `k`-element subsets of an
/// `n`-element set. Returns `0` when `k > n`.
///
/// # Panics
///
/// Panics if the result does not fit in `usize`.
fn number_combinations(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }
    // C(n, k) == C(n, n - k); looping over the smaller side keeps the loop short.
    let r = k.min(n - k);
    let mut result: usize = 1;
    for i in 1..=r {
        // After this step `result == C(n - r + i, i)`, so the division is exact.
        // The product is formed in u128 because `result * (n - r + i)` can exceed `usize`
        // even when the final coefficient fits (e.g. C(66, 33) on 64-bit targets).
        // Both factors are below 2^64, so the u128 product cannot overflow.
        let wide = result as u128 * (n - r + i) as u128 / i as u128;
        // Intermediate values C(n - r + i, i) never exceed the final C(n, k), so failing
        // here means the final result is not representable either.
        assert!(
            wide <= usize::MAX as u128,
            "number of combinations C({}, {}) does not fit in usize",
            n,
            k
        );
        result = wide as usize;
    }
    result
}
#[derive(Debug, Error, PartialEq)]
pub enum CombinationGetIndexError {
#[error("combination must have length {0}, but has length {1}")]
IncorrectLength(usize, usize),
#[error("combination must have elements in the range 0..{0}, but has element {1}")]
ValueTooLarge(usize, usize),
}
#[derive(Debug, Error, PartialEq)]
pub enum CombinationNewError {
#[error("N must be greater than 0")]
ZeroOptions,
#[error("Combinations must N >= t, but N = {0} and t = {1}")]
CombinationTooLarge(usize, usize),
}
/// Lexicographic enumerator over the `t`-element subsets of `{0, 1, ..., n - 1}`.
///
/// Each combination is represented as a strictly increasing `Vec<usize>`. Besides
/// implementing `Iterator`, the type converts between a combination and its position
/// in the enumeration order (`get_index` / `at_index`).
#[derive(Clone, Debug)]
pub struct Combinations {
    // Size of the ground set; combination elements are drawn from `0..n`.
    n: usize,
    // Number of elements in each combination.
    t: usize,
    // Combination most recently yielded by `next`, or `None` before the first call.
    current: Option<Vec<usize>>,
}
impl Combinations {
    /// Creates an enumerator over all `t`-element subsets of `{0, 1, ..., n - 1}`.
    ///
    /// # Errors
    ///
    /// * `CombinationNewError::ZeroOptions` if `n == 0`.
    /// * `CombinationNewError::CombinationTooLarge` if `t > n`.
    pub fn new(n: usize, t: usize) -> Result<Self, CombinationNewError> {
        match (n, t) {
            (0, _) => Err(CombinationNewError::ZeroOptions),
            (n, t) if t > n => Err(CombinationNewError::CombinationTooLarge(n, t)),
            (n, t) => Ok(Self {
                n,
                t,
                current: None,
            }),
        }
    }

    /// Returns how many distinct combinations exist, i.e. the binomial coefficient `C(n, t)`.
    pub fn number_combinations(&self) -> usize {
        let (n, t) = (self.n, self.t);
        number_combinations(n, t)
    }

    /// Returns the position of `combination` in the lexicographic order produced by
    /// iterating over `self`.
    ///
    /// The combination is expected to be strictly increasing. That is not checked:
    /// unsorted or duplicated input yields an index that does not correspond to it.
    ///
    /// # Errors
    ///
    /// * `IncorrectLength(t, len)` if the slice does not have exactly `t` elements.
    /// * `ValueTooLarge(n - 1, element)` for the first element greater than `n - 1`.
    pub fn get_index(&self, combination: &[usize]) -> Result<usize, CombinationGetIndexError> {
        let length = combination.len();
        if length != self.t {
            return Err(CombinationGetIndexError::IncorrectLength(self.t, length));
        }
        let largest = self.n - 1;
        let mut rank = 0;
        // Smallest value that has not yet been consumed or skipped.
        let mut next_candidate = 0;
        for (position, &element) in combination.iter().enumerate() {
            if element > largest {
                return Err(CombinationGetIndexError::ValueTooLarge(largest, element));
            }
            let slots_after = length - position - 1;
            // Every skipped value `v` at this slot accounts for all combinations that put `v`
            // here and fill the remaining slots from the values above it.
            for skipped in next_candidate..element {
                rank += number_combinations(largest - skipped, slots_after);
            }
            next_candidate = next_candidate.max(element) + 1;
        }
        Ok(rank)
    }

    /// Returns the combination at position `index` in lexicographic order, or `None`
    /// when `index >= C(n, t)`. This is the inverse of `get_index`.
    pub fn at_index(&self, index: usize) -> Option<Vec<usize>> {
        let total = number_combinations(self.n, self.t);
        if index >= total {
            return None;
        }
        // Decode the complementary rank with the combinatorial number system, then map
        // each chosen value `upper` back to `n - 1 - upper`.
        let mut remainder = total - 1 - index;
        let mut upper = self.n;
        let mut combination = Vec::with_capacity(self.t);
        for slots in (1..=self.t).rev() {
            upper -= 1;
            while number_combinations(upper, slots) > remainder {
                upper -= 1;
            }
            combination.push(self.n - 1 - upper);
            remainder -= number_combinations(upper, slots);
        }
        Some(combination)
    }
}
impl Iterator for Combinations {
    type Item = Vec<usize>;

    /// Yields the next combination in lexicographic order.
    ///
    /// The first call yields `[0, 1, ..., t - 1]`. Once the last combination
    /// `[n - t, ..., n - 1]` has been yielded, every further call returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let state = match self.current.as_mut() {
            Some(state) => state,
            None => {
                let first: Vec<usize> = (0..self.t).collect();
                self.current = Some(first.clone());
                return Some(first);
            }
        };
        let (n, t) = (self.n, self.t);
        // Rightmost slot that can still grow; slot `pos` tops out at `n - t + pos`.
        // No such slot (including the `t == 0` case) means the enumeration is finished.
        let pivot = (0..t).rev().find(|&pos| state[pos] != n - t + pos)?;
        // Bump the pivot and lay out every following slot consecutively after it.
        let start = state[pivot] + 1;
        for (offset, slot) in state[pivot..].iter_mut().enumerate() {
            *slot = start + offset;
        }
        Some(state.clone())
    }
}
/// Re-inserts `element_to_insert` into a combination that was expressed without it.
///
/// Every element at or after the first entry `>= element_to_insert` is shifted up by one,
/// and `element_to_insert` is placed just before that entry (or appended when no entry
/// qualifies).
///
/// Returns `(position, combination)`, where `position` is the index at which
/// `element_to_insert` now sits in the returned vector.
pub fn insert_element_into_reduced_combination(
    element_to_insert: usize,
    combination_without_element: &[usize],
) -> (usize, Vec<usize>) {
    let split = combination_without_element
        .iter()
        .position(|&existing| existing >= element_to_insert)
        .unwrap_or(combination_without_element.len());
    let (before, after) = combination_without_element.split_at(split);

    let mut expanded = Vec::with_capacity(combination_without_element.len() + 1);
    expanded.extend_from_slice(before);
    expanded.push(element_to_insert);
    // Entries after the insertion point move up by one to make room for the new element.
    expanded.extend(after.iter().map(|&shifted| shifted + 1));
    (split, expanded)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_combinations_5_choose_3() {
        let combinations = Combinations::new(5, 3).unwrap();
        let result: Vec<_> = combinations.collect();
        assert_eq!(
            result,
            vec![
                vec![0, 1, 2],
                vec![0, 1, 3],
                vec![0, 1, 4],
                vec![0, 2, 3],
                vec![0, 2, 4],
                vec![0, 3, 4],
                vec![1, 2, 3],
                vec![1, 2, 4],
                vec![1, 3, 4],
                vec![2, 3, 4],
            ]
        );
    }

    #[test]
    fn test_combinations_4_choose_2() {
        let combinations = Combinations::new(4, 2).unwrap();
        let result: Vec<_> = combinations.collect();
        assert_eq!(
            result,
            vec![
                vec![0, 1],
                vec![0, 2],
                vec![0, 3],
                vec![1, 2],
                vec![1, 3],
                vec![2, 3],
            ]
        );
    }

    #[test]
    fn test_combinations_are_lexicographically_sorted() {
        let combinations = Combinations::new(5, 3).unwrap();
        let result: Vec<_> = combinations.collect();
        let mut sorted_result = result.clone();
        sorted_result.sort();
        assert_eq!(result, sorted_result);
    }

    #[test]
    fn test_combinations_number_combinations() {
        let combinations = Combinations::new(5, 3).unwrap();
        assert_eq!(combinations.number_combinations(), 10);
    }

    #[test]
    fn test_combinations_get_index() {
        let combinations = Combinations::new(5, 3).unwrap();
        for (index, combination) in combinations.clone().enumerate() {
            assert_eq!(combinations.get_index(&combination).unwrap(), index);
        }
    }

    #[test]
    fn test_combinations_at_index() {
        let combinations = Combinations::new(5, 3).unwrap();
        for (index, combination) in combinations.clone().enumerate() {
            assert_eq!(combinations.at_index(index).unwrap(), combination);
        }
    }

    // Ranking, unranking and iteration must agree for every (n, t), including the
    // boundary cases t == 0, t == 1 and t == n.
    #[test]
    fn test_combinations_index_round_trip_various_sizes() {
        for n in 1..=7 {
            for t in 0..=n {
                let combinations = Combinations::new(n, t).unwrap();
                let all: Vec<_> = combinations.clone().collect();
                assert_eq!(all.len(), combinations.number_combinations());
                for (index, combination) in all.iter().enumerate() {
                    assert_eq!(combinations.get_index(combination), Ok(index));
                    assert_eq!(combinations.at_index(index).as_ref(), Some(combination));
                }
                assert_eq!(combinations.at_index(all.len()), None);
            }
        }
    }

    #[test]
    fn test_combinations_iterator_stays_exhausted() {
        let mut combinations = Combinations::new(3, 3).unwrap();
        assert_eq!(combinations.next(), Some(vec![0, 1, 2]));
        assert_eq!(combinations.next(), None);
        assert_eq!(combinations.next(), None);
    }

    #[test]
    fn test_combinations_at_index_out_of_bounds() {
        let combinations = Combinations::new(5, 3).unwrap();
        assert_eq!(combinations.at_index(10), None);
    }

    #[test]
    fn test_combinations_get_index_incorrect_length() {
        let combinations = Combinations::new(5, 3).unwrap();
        assert_eq!(
            combinations.get_index(&[0, 1]),
            Err(CombinationGetIndexError::IncorrectLength(3, 2))
        );
    }

    #[test]
    fn test_combinations_get_index_value_too_large() {
        let combinations = Combinations::new(5, 3).unwrap();
        assert_eq!(
            combinations.get_index(&[0, 1, 5]),
            Err(CombinationGetIndexError::ValueTooLarge(4, 5))
        );
    }

    #[test]
    fn test_combinations_new_combination_too_large() {
        let combinations = Combinations::new(5, 6);
        assert_eq!(
            combinations.unwrap_err(),
            CombinationNewError::CombinationTooLarge(5, 6)
        );
    }

    #[test]
    fn test_combinations_new_zero_options() {
        let combinations = Combinations::new(0, 3);
        assert_eq!(combinations.unwrap_err(), CombinationNewError::ZeroOptions);
    }

    // A zero-sized combination is valid: exactly one (empty) combination exists.
    #[test]
    fn test_combinations_new_zero_combination_size() {
        let combinations = Combinations::new(5, 0).unwrap();
        assert_eq!(combinations.number_combinations(), 1);
        assert_eq!(combinations.get_index(&[]), Ok(0));
        assert_eq!(combinations.at_index(0), Some(vec![]));
        assert_eq!(combinations.at_index(1), None);
        let result: Vec<Vec<usize>> = combinations.collect();
        assert_eq!(result, vec![Vec::<usize>::new()]);
    }

    #[test]
    fn test_insert_element_into_reduced_combination() {
        assert_eq!(
            insert_element_into_reduced_combination(1, &[0, 2, 3]),
            (1, vec![0, 1, 3, 4])
        );
        assert_eq!(
            insert_element_into_reduced_combination(0, &[1, 2, 3]),
            (0, vec![0, 2, 3, 4])
        );
        assert_eq!(
            insert_element_into_reduced_combination(3, &[0, 1, 2]),
            (3, vec![0, 1, 2, 3])
        );
        assert_eq!(insert_element_into_reduced_combination(0, &[]), (0, vec![0]));
    }
}