integer_sqrt/
lib.rs

//!
//! This crate provides the single trait [`IntegerSquareRoot`] and implements it for the
//! primitive integer types.
//!
//! # Example
//!
//! ```
//! extern crate integer_sqrt;
//! // Bring the trait into scope to get the `integer_sqrt` method.
//! use integer_sqrt::IntegerSquareRoot;
//!
//! # fn main() {
//! assert_eq!(4u8.integer_sqrt(), 2);
//! # }
//! ```
//!
//! [`IntegerSquareRoot`]: ./trait.IntegerSquareRoot.html
18#![no_std]
19
/// Integer square root: the largest integer `r` such that `r * r <= n`.
pub trait IntegerSquareRoot {
    /// Find the integer square root.
    ///
    /// See [Integer_square_root on wikipedia][wiki_article] for more information (and also the
    /// source of this algorithm).
    ///
    /// # Panics
    ///
    /// Panics when `integer_sqrt_checked` returns `None`, i.e. for a negative input of a
    /// signed type.
    ///
    /// [wiki_article]: https://en.wikipedia.org/wiki/Integer_square_root
    fn integer_sqrt(&self) -> Self
    where
        Self: Sized,
    {
        // Delegate to the checked variant; its only failure mode is a negative input.
        match self.integer_sqrt_checked() {
            Some(root) => root,
            None => panic!("cannot calculate square root of negative number"),
        }
    }

    /// Find the integer square root, returning `None` if the number is negative (this can never
    /// happen for unsigned types).
    fn integer_sqrt_checked(&self) -> Option<Self>
    where
        Self: Sized;
}
46
47impl<T: num_traits::PrimInt> IntegerSquareRoot for T {
48    fn integer_sqrt_checked(&self) -> Option<Self> {
49        use core::cmp::Ordering;
50        match self.cmp(&T::zero()) {
51            // Hopefully this will be stripped for unsigned numbers (impossible condition)
52            Ordering::Less => return None,
53            Ordering::Equal => return Some(T::zero()),
54            _ => {}
55        }
56
57        // Compute bit, the largest power of 4 <= n
58        let max_shift: u32 = T::zero().leading_zeros() - 1;
59        let shift: u32 = (max_shift - self.leading_zeros()) & !1;
60        let mut bit = T::one().unsigned_shl(shift);
61
62        // Algorithm based on the implementation in:
63        // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)
64        // Note that result/bit are logically unsigned (even if T is signed).
65        let mut n = *self;
66        let mut result = T::zero();
67        while bit != T::zero() {
68            if n >= (result + bit) {
69                n = n - (result + bit);
70                result = result.unsigned_shr(1) + bit;
71            } else {
72                result = result.unsigned_shr(1);
73            }
74            bit = bit.unsigned_shr(2);
75        }
76        Some(result)
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::IntegerSquareRoot;

    /// Generates one `#[test]` per integer type. Each test checks a handful of small known roots
    /// plus the roots of the two largest values representable by that type.
    macro_rules! gen_tests {
        ($($type:ty => $fn_name:ident),*) => {
            $(
                #[test]
                fn $fn_name() {
                    // Reference value for floor(sqrt(MAX)). The `f64` estimate is only a
                    // starting point: with a 53-bit mantissa it can be off by more than one, in
                    // either direction, for the widest types. So it is corrected exactly with
                    // checked integer multiplication.
                    let max_sqrt = {
                        let square = <$type>::max_value();
                        let mut value = (square as f64).sqrt() as $type;
                        // `value * value` overflows exactly when value² > MAX: step down.
                        while value.checked_mul(value).is_none() {
                            value -= 1;
                        }
                        // `(value + 1)²` fits exactly when it is <= MAX: step up.
                        while (value + 1).checked_mul(value + 1).is_some() {
                            value += 1;
                        }
                        value
                    };
                    let tests: [($type, $type); 9] = [
                        (0, 0),
                        (1, 1),
                        (2, 1),
                        (3, 1),
                        (4, 2),
                        (81, 9),
                        (80, 8),
                        (<$type>::max_value(), max_sqrt),
                        (<$type>::max_value() - 1, max_sqrt),
                    ];
                    for &(in_, out) in tests.iter() {
                        assert_eq!(in_.integer_sqrt(), out, "in {}", in_);
                    }
                }
            )*
        };
    }

    /// Generates one `#[test]` per signed type, checking that negative inputs (including the
    /// minimum value) make `integer_sqrt_checked` return `None`.
    macro_rules! gen_negative_tests {
        ($($type:ty => $fn_name:ident),*) => {
            $(
                #[test]
                fn $fn_name() {
                    let inputs: [$type; 4] = [-1, -2, -81, <$type>::min_value()];
                    for &in_ in inputs.iter() {
                        assert_eq!(in_.integer_sqrt_checked(), None, "in {}", in_);
                    }
                }
            )*
        };
    }

    /// Generates one `#[test]` per small type that checks every non-negative value against the
    /// defining property `r² <= n < (r + 1)²`. The check is evaluated in `i64`, so it cannot
    /// overflow.
    macro_rules! gen_exhaustive_tests {
        ($($type:ty => $fn_name:ident),*) => {
            $(
                #[test]
                fn $fn_name() {
                    let max = <$type>::max_value();
                    for in_ in 0..=max {
                        let root = i64::from(in_.integer_sqrt());
                        let value = i64::from(in_);
                        assert!(
                            root * root <= value && value < (root + 1) * (root + 1),
                            "in {}: got {}",
                            in_,
                            root
                        );
                    }
                }
            )*
        };
    }

    gen_tests! {
        i8 => i8_test,
        u8 => u8_test,
        i16 => i16_test,
        u16 => u16_test,
        i32 => i32_test,
        u32 => u32_test,
        i64 => i64_test,
        u64 => u64_test,
        i128 => i128_gen_test,
        u128 => u128_test,
        isize => isize_test,
        usize => usize_test
    }

    gen_negative_tests! {
        i8 => i8_negative_test,
        i16 => i16_negative_test,
        i32 => i32_negative_test,
        i64 => i64_negative_test,
        i128 => i128_negative_test,
        isize => isize_negative_test
    }

    gen_exhaustive_tests! {
        i8 => i8_exhaustive_test,
        u8 => u8_exhaustive_test,
        i16 => i16_exhaustive_test,
        u16 => u16_exhaustive_test
    }

    #[test]
    #[should_panic(expected = "cannot calculate square root of negative number")]
    fn negative_input_panics() {
        let _ = (-1i32).integer_sqrt();
    }

    /// Independent, hard-coded check of the i128 maximum. The expected value does not come
    /// from the float-plus-correction computation used in `gen_tests!`.
    #[test]
    fn i128_test() {
        let tests: [(i128, i128); 8] = [
            (0, 0),
            (1, 1),
            (2, 1),
            (3, 1),
            (4, 2),
            (81, 9),
            (80, 8),
            (i128::max_value(), 13_043_817_825_332_782_212),
        ];
        for &(in_, out) in tests.iter() {
            assert_eq!(in_.integer_sqrt(), out, "in {}", in_);
        }
    }
}