use num_bigint::BigUint;
use num_integer::Integer;
use num_traits::{One, Zero};
use std::ops::{Add, Mul};
/// The elements of an addition chain, in order.
///
/// Chains built in this module start at 1 and end at the value they compute, and
/// every later element is the sum of two (not necessarily distinct) earlier ones.
/// The `Add` (`⊕`, append one step) and `Mul` (`⊗`, compose two chains) impls
/// combine chains and expect non-empty chains.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(super) struct Chain(Vec<BigUint>);
impl Add<BigUint> for Chain {
    type Output = Self;

    /// Extends the chain by one step (the `⊕` operation): appends `k + last`, where
    /// `last` is the current final element.
    ///
    /// The result is only a valid addition chain if `k` already appears in the
    /// chain; this is not checked here.
    ///
    /// # Panics
    ///
    /// Panics if the chain is empty.
    fn add(mut self, k: BigUint) -> Self {
        // Compute the new element in its own scope so the shared borrow of the
        // vector ends before we push into it.
        let next = {
            let tail = self.0.last().expect("chain is not empty");
            k + tail
        };
        self.0.push(next);
        self
    }
}
impl Mul<Chain> for Chain {
    type Output = Self;

    /// Composes two chains (the `⊗` operation).
    ///
    /// If `self` ends at `a` and `other` is a chain `1, b₁, …, b`, the result is
    /// `self` followed by `a·b₁, …, a·b`, which is a chain ending at `a·b`. The
    /// leading `1` of `other` is dropped because `a·1 = a` is already the last
    /// element of `self`.
    ///
    /// # Panics
    ///
    /// Panics if `self` is empty, or if `other` is empty or does not start at 1.
    fn mul(mut self, other: Chain) -> Self {
        let factor = self.0.last().expect("chain is not empty");
        let mut tail = other.0;
        // Check the invariant explicitly so an empty or malformed right-hand chain
        // fails with a meaningful message instead of an index-out-of-bounds panic.
        assert!(
            matches!(tail.first(), Some(head) if head.is_one()),
            "right-hand chain must be non-empty and start at 1"
        );
        for w in &mut tail[1..] {
            *w *= factor;
        }
        // Skip the leading 1 without shifting the vector (unlike `remove(0)`).
        self.0.extend(tail.into_iter().skip(1));
        self
    }
}
pub(super) fn find_shortest_chain(n: BigUint) -> Vec<BigUint> {
minchain(n).0
}
/// Returns a short addition chain for `n`, following the Bergeron–Berstel–Brlek–Duboc
/// continued-fraction method with the dichotomic choice of `k`.
///
/// - Powers of two get the doubling chain `1, 2, 4, …, n`.
/// - 3 gets `1, 2, 3`.
/// - Any other `n` is split with `k = ⌊n / 2^⌊log₂(n)/2⌋⌋` (roughly the top half
///   of the bits of `n`) and handed to `chain(n, k)`.
///
/// # Panics
///
/// Panics if `n` is zero: an addition chain starts at 1, so there is no chain for 0.
fn minchain(n: BigUint) -> Chain {
    // `bits() - 1` would underflow for n == 0. That is an overflow panic in debug
    // builds, and in release it wraps into an absurdly large shift. Reject it
    // explicitly instead.
    let log_n = n
        .bits()
        .checked_sub(1)
        .expect("addition chains are only defined for n >= 1");
    if n == BigUint::one() << log_n {
        Chain((0..=log_n).map(|i| BigUint::one() << i).collect())
    } else if n == BigUint::from(3u32) {
        // 3 must be special-cased: with log_n == 1 the shift below is 0, so k == n
        // and `chain(&n, n)` would recurse forever.
        Chain(vec![BigUint::one(), BigUint::from(2u32), n])
    } else {
        // Here n >= 5, so log_n >= 2, the shift is at least 1 and k < n.
        let k = &n / (BigUint::one() << (log_n / 2));
        chain(&n, k)
    }
}
/// Builds an addition chain that ends at `n` and passes through `k`.
///
/// Writing `n = q·k + r`:
/// - `r == 0`: `minchain(k) ⊗ minchain(q)`;
/// - `r == 1`: `minchain(k) ⊗ minchain(q) ⊕ 1`. This is a shortcut, since
///   `chain(k, 1)` would reduce to `minchain(k)` anyway;
/// - otherwise: `chain(k, r) ⊗ minchain(q) ⊕ r`, recursing on the smaller pair
///   `(k, r)`.
///
/// `k` must be non-zero, because it is used as a divisor.
fn chain(n: &BigUint, k: BigUint) -> Chain {
    let (quotient, remainder) = n.div_rem(&k);
    if remainder.is_zero() {
        return minchain(k) * minchain(quotient);
    }
    let prefix = if remainder.is_one() {
        minchain(k)
    } else {
        chain(&k, remainder.clone())
    };
    prefix * minchain(quotient) + remainder
}
#[cfg(test)]
mod tests {
    use num_bigint::BigUint;

    use super::{find_shortest_chain, minchain};

    /// Converts a list of small integers into the `BigUint` chain representation.
    fn big(steps: &[u32]) -> Vec<BigUint> {
        steps.iter().copied().map(BigUint::from).collect()
    }

    /// Asserts that `chain` is a valid addition chain for `n`. It must start at 1,
    /// end at `n`, and every later element must be the sum of two (not necessarily
    /// distinct) earlier elements.
    fn assert_valid_chain(chain: &[BigUint], n: u32) {
        assert_eq!(
            chain.first(),
            Some(&BigUint::from(1u32)),
            "chain for {} must start at 1",
            n
        );
        assert_eq!(
            chain.last(),
            Some(&BigUint::from(n)),
            "chain for {} must end at {}",
            n,
            n
        );
        for (i, step) in chain.iter().enumerate().skip(1) {
            let earlier = &chain[..i];
            assert!(
                earlier
                    .iter()
                    .any(|a| earlier.iter().any(|b| a + b == *step)),
                "element {} of the chain for {} is not a sum of two earlier elements",
                step,
                n
            );
        }
    }

    #[test]
    fn minchain_1() {
        assert_eq!(minchain(BigUint::from(1u32)).0, big(&[1]));
    }

    #[test]
    fn minchain_3() {
        assert_eq!(minchain(BigUint::from(3u32)).0, big(&[1, 2, 3]));
    }

    #[test]
    fn minchain_87() {
        assert_eq!(
            minchain(BigUint::from(87u32)).0,
            big(&[1, 2, 3, 6, 7, 10, 20, 40, 80, 87])
        );
    }

    #[test]
    fn minchain_384() {
        assert_eq!(
            minchain(BigUint::from(384u32)).0,
            big(&[1, 2, 3, 6, 12, 24, 48, 96, 192, 384])
        );
    }

    #[test]
    fn chains_are_valid() {
        for n in 1u32..=256 {
            assert_valid_chain(&find_shortest_chain(BigUint::from(n)), n);
        }
    }
}