use num_traits::{ConstOne, ConstZero, PrimInt};
use std::borrow::Borrow;
use std::cmp::Ordering;
/// Double-ended iterator over the base-`radix` digits of a non-negative integer.
///
/// Created by [`digits`] or [`DigitsIter::new`]. `next` yields digits from least to
/// most significant; `next_back` yields them from most to least significant.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct DigitsIter<T> {
    /// The integer formed by the digits that have not been yielded yet.
    num: T,
    /// Base of the digit expansion.
    radix: T,
    /// Place value of the most-significant remaining digit; zero once exhausted.
    front_weight: T,
}
impl<T: PrimInt + ConstZero + ConstOne> DigitsIter<T> {
    /// Creates an iterator over the base-`radix` digits of `num`.
    ///
    /// Zero is treated as having exactly one digit (`0`).
    ///
    /// # Panics
    /// - if `radix < 2`;
    /// - if `radix` cannot be represented in `T`;
    /// - if `num` is negative.
    pub fn new(num: T, radix: u8) -> Self {
        assert!(radix >= 2, "Radix must be at least 2.");
        let radix = T::from(radix).expect("Radix must fit in the type T.");

        let digit_count: u32 = if num < T::ZERO {
            panic!("Integer must be non-negative.");
        } else if num == T::ZERO {
            1
        } else {
            // Widen to u128 so a single integer-log routine serves every primitive width.
            let value = num.to_u128().unwrap();
            let base = radix.to_u128().unwrap();
            value.ilog(base) + 1
        };

        // Place value of the leading digit: radix^(digit_count - 1), which is <= num.
        let front_weight = radix.pow(digit_count - 1);
        Self {
            num,
            radix,
            front_weight,
        }
    }
}
impl<T: PrimInt + ConstZero> Iterator for DigitsIter<T> {
    type Item = u8;

    /// Removes and returns the least-significant remaining digit.
    fn next(&mut self) -> Option<Self::Item> {
        let exhausted = self.num == T::ZERO && self.front_weight == T::ZERO;
        if exhausted {
            return None;
        }
        let low_digit = self.num % self.radix;
        self.num = self.num / self.radix;
        // One fewer digit remains, so the leading place value drops by one power of the radix.
        self.front_weight = self.front_weight / self.radix;
        Some(low_digit.to_u8().unwrap())
    }

    /// Reports the exact number of digits still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = if self.front_weight == T::ZERO {
            0
        } else {
            // front_weight is radix^(remaining - 1), so its integer log recovers the count.
            let weight = self.front_weight.to_u128().unwrap();
            let base = self.radix.to_u128().unwrap();
            weight.ilog(base) as usize + 1
        };
        (remaining, Some(remaining))
    }

    /// Returns the number of remaining digits without draining the iterator.
    fn count(self) -> usize {
        let (exact, _) = self.size_hint();
        exact
    }
}
impl<T: PrimInt + ConstZero> DoubleEndedIterator for DigitsIter<T> {
    /// Removes and returns the most-significant remaining digit.
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.front_weight == T::ZERO && self.num == T::ZERO {
            return None;
        }
        let weight = self.front_weight;
        // Dividing by the leading place value isolates the leading digit;
        // the remainder is the number formed by the digits after it.
        let high_digit = self.num / weight;
        self.num = self.num % weight;
        self.front_weight = weight / self.radix;
        Some(high_digit.to_u8().unwrap())
    }
}
/// `size_hint` is always exact: it is derived from the remaining leading place value.
impl<T: PrimInt + ConstZero> ExactSizeIterator for DigitsIter<T> {}

/// Once exhausted, `num` and `front_weight` both stay zero, so `next` and `next_back`
/// keep returning `None`; adapters such as `.fuse()` can therefore skip their extra guard.
impl<T: PrimInt + ConstZero> std::iter::FusedIterator for DigitsIter<T> {}
/// Returns an iterator over the base-`radix` digits of `n`.
///
/// Digits are yielded least-significant first; call `.rev()` for the usual
/// written order. Zero yields a single `0` digit.
///
/// # Panics
/// Same as [`DigitsIter::new`]: `radix < 2`, `radix` not representable in `T`,
/// or negative `n`.
pub fn digits<T: PrimInt + ConstZero + ConstOne>(n: T, radix: u8) -> DigitsIter<T> {
    DigitsIter::<T>::new(n, radix)
}
/// Builds an integer from its base-`radix` digits, given least-significant first.
///
/// This is the inverse of [`digits`]: `digits_to_int(digits(n, r), r) == n`.
/// An empty sequence yields zero. Zero digits at the most-significant end are
/// accepted even when their place value would exceed the range of `V`.
///
/// # Panics
/// - if `radix < 2`;
/// - if `radix` or a digit cannot be represented in `V`;
/// - if a digit is not less than `radix`;
/// - if the resulting value does not fit in `V`.
pub fn digits_to_int<T, U, V>(digits: T, radix: u8) -> V
where
    T: IntoIterator<Item = U>,
    U: Borrow<u8>,
    V: PrimInt + ConstZero + ConstOne,
{
    if radix < 2 {
        panic!("Radix must be at least 2.");
    }
    let radix = V::from(radix).expect("Radix must fit in the type V.");
    let mut result = V::ZERO;
    // Place value of the current digit. Advancing it after the last significant digit
    // can overflow even though the result fits (e.g. 255u8 from [5, 5, 2] would need
    // a weight of 1000), so weight overflow is recorded as `None` and only reported
    // if a non-zero digit actually needs that weight.
    let mut base = Some(V::ONE);
    for digit in digits {
        // A u8 converted to V is never negative, so only the upper bound needs checking.
        let digit = V::from(*digit.borrow()).expect("Digit must fit in the type V.");
        if digit >= radix {
            panic!("Digits must be less than the radix.");
        }
        if digit != V::ZERO {
            result = base
                .and_then(|b| b.checked_mul(&digit))
                .and_then(|term| result.checked_add(&term))
                .expect("Result must fit in the type V.");
        }
        base = base.and_then(|b| b.checked_mul(&radix));
    }
    result
}
/// Returns `true` if the base-`radix` digits of `n` read the same in both directions.
///
/// Digits are compared pairwise from the two ends: the least-significant remaining
/// digit is taken first, then the most-significant one, until the ends meet.
///
/// # Panics
/// Same as [`digits`]: `radix < 2`, `radix` not representable in `T`, or negative `n`.
pub fn is_palindrome<T>(n: T, radix: u8) -> bool
where
    T: PrimInt + ConstZero + ConstOne,
{
    let mut remaining = digits(n, radix);
    loop {
        let Some(low) = remaining.next() else {
            return true;
        };
        // Running out here means `low` was the unpaired middle digit.
        let Some(high) = remaining.next_back() else {
            return true;
        };
        if low != high {
            return false;
        }
    }
}
/// Returns `true` if `n` and `m` have the same multiset of base-`radix` digits.
///
/// Only the digits each number actually has are counted, so leading zeros are not
/// implied (e.g. `10` and `1` are not permutations of each other).
///
/// # Panics
/// Same as [`digits`], for either argument.
pub fn is_permutation<T>(n: T, m: T, radix: u8) -> bool
where
    T: PrimInt + ConstZero + ConstOne,
{
    // Net tally per digit value: +1 for each digit of `n`, -1 for each digit of `m`.
    // A u8 digit always indexes within 256 slots, and at most 128 digits fit in any
    // primitive integer, so i16 cannot overflow.
    let mut balance = [0_i16; 256];
    digits(n, radix).for_each(|d| balance[usize::from(d)] += 1);
    digits(m, radix).for_each(|d| balance[usize::from(d)] -= 1);
    balance.iter().all(|&count| count == 0)
}
/// Returns the integer whose base-`radix` digits are those of `n` in reverse order.
///
/// Trailing zeros of `n` become leading zeros and vanish (e.g. 120 -> 21 in base 10).
/// The reversed value can exceed `n` (e.g. 19 -> 91), so it may not fit in `T`;
/// overflow is handled by [`digits_to_int`].
pub fn reverse<T: PrimInt + ConstZero + ConstOne>(n: T, radix: u8) -> T {
    let most_significant_first = digits(n, radix).rev();
    digits_to_int::<_, u8, T>(most_significant_first, radix)
}