nbits/core/
arith.rs

1use super::BitIterator;
2use super::Bitwise;
3
/// In-place arithmetic over big-endian byte sequences.
///
/// This module implements the trait for `[u8]`. Every mutating operator
/// writes its result back into `self` and returns a `bool` status flag whose
/// meaning is described on each method.
pub trait BitArith {
    /// Type of the right-hand operand taken by the mutating operators.
    type Other: ?Sized;

    /// Numeric comparison of two big-endian values.
    ///
    /// The `[u8]` implementation pads the shorter operand with leading zero
    /// bytes, so operands of different lengths are compared by value.
    /// # Examples
    /// ```
    /// # use nbits::BitArith;
    /// # use std::cmp::Ordering;
    /// assert_eq!([0b0011_0011, 0b0011_0011].bit_be_cmp(&[0b1111_1111]), Ordering::Greater);
    /// assert_eq!([0b0000_0000, 0b0011_0011].bit_be_cmp(&[0b1111_1111]), Ordering::Less);
    /// assert_eq!([0b0011_0011, 0b0011_0011].bit_be_cmp(&[0b0000_0000, 0b1111_1111]), Ordering::Greater);
    /// assert_eq!([0b0011_0011, 0b0011_0011].bit_be_cmp(&[0b1111_1111, 0b0000_0000]), Ordering::Less);
    /// ```
    fn bit_be_cmp(&self, other: &Self) -> std::cmp::Ordering;

    /// Big-endian `+=`: adds `other` into `self`.
    ///
    /// For `[u8]` the returned flag is the carry out of the most significant
    /// byte of `self`.
    /// # Examples
    /// ```
    /// # use nbits::BitArith;
    /// let (mut a, b) = ([0b1100_1100, 0b1000_0001], [0b1000_0001]);
    /// assert_eq!(a.as_mut().bit_be_add(&b), false);
    /// assert_eq!(a, [0b1100_1101, 0b0000_0010]);
    /// ```
    fn bit_be_add(&mut self, other: &Self::Other) -> bool;

    /// Big-endian `-=`: subtracts `other` from `self`.
    ///
    /// For `[u8]` the returned flag is the borrow out of the most significant
    /// byte of `self`.
    /// # Examples
    /// ```
    /// # use nbits::BitArith;
    /// let (mut a, b) = ([0b1100_1100, 0b1000_0001], [0b1000_0001]);
    /// assert_eq!(a.as_mut().bit_be_sub(&b), false);
    /// assert_eq!(a, [0b1100_1100, 0b0000_0000]);
    /// ```
    fn bit_be_sub(&mut self, other: &Self::Other) -> bool;

    /// Big-endian `*=`: multiplies `self` by `other`, keeping the width of `self`.
    ///
    /// For `[u8]` the returned flag reports that the product did not fit.
    /// # Examples
    /// ```
    /// # use nbits::BitArith;
    /// let (mut a, b) = ([0b0011_0000, 0b1000_0001], [0b0000_0010]);
    /// assert_eq!(a.as_mut().bit_be_mul(&b), false);
    /// assert_eq!(a, [0b0110_0001, 0b0000_0010]);
    /// ```
    fn bit_be_mul(&mut self, other: &Self::Other) -> bool;

    /// Big-endian `/=`: replaces `self` with the quotient `self / other`.
    ///
    /// For `[u8]` the flag is `true` only when `other` is zero, in which case
    /// `self` is left as it was.
    /// # Examples
    /// ```
    /// # use nbits::BitArith;
    /// let (a, b) = ([0b1100_0011, 0b0000_0001], [0b1000_0001]);
    /// let mut x = a.clone();
    /// x.as_mut().bit_be_div(&b);
    /// assert_eq!(x, (u16::from_be_bytes(a) / u16::from_be_bytes([0, b[0]])).to_be_bytes());
    /// ```
    fn bit_be_div(&mut self, other: &Self::Other) -> bool;

    /// Big-endian `%=`: replaces `self` with the remainder `self % other`.
    ///
    /// For `[u8]` the flag is `true` only when `other` is zero, in which case
    /// `self` is left as it was.
    /// # Examples
    /// ```
    /// # use nbits::BitArith;
    /// let (a, b) = ([0b1100_0011, 0b0000_0001], [0b0000_0001, 0b1000_0001]);
    /// let mut x = a.clone();
    /// x.as_mut().bit_be_rem(&b);
    /// // assert_eq!(x, (u16::from_be_bytes(a) % u16::from_be_bytes(b)).to_be_bytes());
    /// ```
    fn bit_be_rem(&mut self, other: &Self::Other) -> bool;
}
74
75impl BitArith for [u8] {
76    type Other = Self;
77
78    fn bit_be_cmp(&self, other: &Self) -> std::cmp::Ordering {
79        let max_len = std::cmp::max(self.len(), other.len());
80        self.extend_be_iter(max_len)
81            .cmp(other.extend_be_iter(max_len))
82    }
83
84    fn bit_be_add(&mut self, other: &Self) -> bool {
85        self.iter_mut()
86            .rev()
87            .zip(other.iter().rev().chain(std::iter::repeat(&0)))
88            .fold(false, |mut carry, (a, b)| {
89                match (carry, *b) {
90                    (true, 0xff) => carry = true,
91                    (true, _) => (*a, carry) = a.overflowing_add(b + 1),
92                    (false, _) => (*a, carry) = a.overflowing_add(*b),
93                };
94                carry
95            })
96    }
97
98    fn bit_be_sub(&mut self, other: &Self) -> bool {
99        self.iter_mut()
100            .rev()
101            .zip(other.iter().rev().chain(std::iter::repeat(&0)))
102            .fold(false, |mut borrow, (a, b)| {
103                match (borrow, *b) {
104                    (true, 0xff) => borrow = true,
105                    (true, _) => (*a, borrow) = a.overflowing_sub(b + 1),
106                    (false, _) => (*a, borrow) = a.overflowing_sub(*b),
107                };
108                borrow
109            })
110    }
111
112    fn bit_be_mul(&mut self, other: &Self) -> bool {
113        let mut result = vec![0; self.len()];
114        let mut overflow = false;
115        for (i, bit) in other.bit_iter().rev().enumerate() {
116            if bit {
117                let mut multiple = self.to_vec();
118                overflow |= multiple.bit_shl(i);
119                overflow |= result.bit_be_add(&multiple);
120            }
121        }
122        self.copy_from_slice(&result);
123        overflow
124    }
125
126    fn bit_be_div(&mut self, other: &Self) -> bool {
127        if other.iter().all(|&b| b == 0) {
128            return true; // Division by zero, return overflow
129        }
130
131        // Ignore leading zeros
132        let bits_a = self.len() * 8 - self.bit_leading_zeros(); // effective bits length
133        let bits_b = other.len() * 8 - other.bit_leading_zeros(); // effective bits length
134        if bits_a < bits_b {
135            self.fill(0);
136            return false;
137        }
138
139        let mut other = other.extend_be(self.len()); // extend to the same length
140        {
141            // Remove common trailing zeros
142            let common_divisor_bits = self.bit_trailing_zeros().min(other.bit_trailing_zeros());
143            self.bit_shr(common_divisor_bits);
144            other.bit_shr(common_divisor_bits);
145        }
146
147        // Perform division
148        let n = self.len();
149        let mut result = vec![0; n];
150        let diff = bits_a - bits_b;
151        other.bit_shl(diff);
152        for i in (0..=diff).rev() {
153            if self.bit_be_cmp(&other) != std::cmp::Ordering::Less {
154                self.bit_be_sub(&other);
155                result[n - 1 - i / 8] |= 1 << (i % 8);
156            }
157            other.bit_shr(1);
158        }
159
160        self.copy_from_slice(&result);
161        false
162    }
163
164    fn bit_be_rem(&mut self, other: &Self) -> bool {
165        if other.iter().all(|&b| b == 0) {
166            return true; // Division by zero, return overflow
167        }
168
169        // Ignore leading zeros
170        let bits_a = self.len() * 8 - self.bit_leading_zeros(); // effective bits length
171        let bits_b = other.len() * 8 - other.bit_leading_zeros(); // effective bits length
172        if bits_a < bits_b {
173            return false;
174        }
175
176        let mut other = other.extend_be(self.len()); // extend to the same length
177        {
178            // Remove common trailing zeros
179            let common_divisor_bits = self.bit_trailing_zeros().min(other.bit_trailing_zeros());
180            self.bit_shr(common_divisor_bits);
181            other.bit_shr(common_divisor_bits);
182        }
183
184        // Perform division
185        let n = self.len();
186        let mut result = vec![0; n];
187        let diff = bits_a - bits_b;
188        other.bit_shl(diff);
189        for i in (0..=diff).rev() {
190            if self.bit_be_cmp(&other) != std::cmp::Ordering::Less {
191                self.bit_be_sub(&other);
192                result[n - 1 - i / 8] |= 1 << (i % 8);
193            }
194            other.bit_shr(1);
195        }
196        false
197    }
198}
199
/// Helpers for viewing a big-endian byte slice at a different width.
trait ByteExtend {
    /// Iterates over `self` preceded by `n - self.len()` zero bytes.
    ///
    /// `n` is expected to be at least `self.len()`; the padding length is
    /// computed with a plain subtraction.
    fn extend_be_iter(&self, n: usize) -> impl DoubleEndedIterator<Item = &u8>;

    /// Returns an owned copy of `self` that is exactly `n` bytes wide.
    fn extend_be(&self, n: usize) -> Vec<u8>;
}

impl ByteExtend for [u8] {
    #[inline(always)]
    fn extend_be_iter(&self, n: usize) -> impl DoubleEndedIterator<Item = &u8> {
        let padding = n - self.len();
        std::iter::repeat_n(&0, padding).chain(self.iter())
    }

    /// Widening pads with zero bytes on the left. Narrowing drops leading
    /// bytes and panics unless every dropped byte is zero.
    #[inline(always)]
    fn extend_be(&self, n: usize) -> Vec<u8> {
        if self.len() < n {
            // Zero padding first, then the original bytes in the low positions.
            let mut widened = vec![0; n - self.len()];
            widened.extend_from_slice(self);
            return widened;
        }

        // Narrowing (or same width): the discarded prefix must carry no value.
        assert!(self[..self.len() - n].iter().all(|&b| b == 0));
        self[self.len() - n..].to_vec()
    }
}
224
#[cfg(test)]
mod test_arith {
    use super::*;

    /// `+=` reports the carry out of the top byte and stores the wrapped sum.
    #[test]
    fn test_bits_add() {
        // (left operand, right operand, expected stored value, expected flag)
        let cases: [([u8; 2], [u8; 1], [u8; 2], bool); 2] = [
            ([0b1111_1111, 0b1111_1111], [0b0000_0001], [0b0000_0000, 0b0000_0000], true),
            ([0b0000_0000, 0b0000_0001], [0b1111_1111], [0b0000_0001, 0b0000_0000], false),
        ];
        for (mut lhs, rhs, expected, flag) in cases {
            assert_eq!(lhs.bit_be_add(&rhs), flag);
            assert_eq!(lhs, expected);
        }
    }

    /// `-=` reports the borrow out of the top byte and stores the wrapped difference.
    #[test]
    fn test_bits_sub() {
        // (left operand, right operand, expected stored value, expected flag)
        let cases: [([u8; 2], [u8; 1], [u8; 2], bool); 2] = [
            ([0b0000_0000, 0b0000_0001], [0b1111_1111], [0b1111_1111, 0b0000_0010], true),
            ([0b1111_1111, 0b0000_0000], [0b0000_0001], [0b1111_1110, 0b1111_1111], false),
        ];
        for (mut lhs, rhs, expected, flag) in cases {
            assert_eq!(lhs.bit_be_sub(&rhs), flag);
            assert_eq!(lhs, expected);
        }
    }

    /// `*=` keeps the low bytes of the product and flags a product that does not fit.
    #[test]
    fn test_bits_mul() {
        // (left operand, right operand, expected stored value, expected flag)
        let cases: [([u8; 2], [u8; 1], [u8; 2], bool); 2] = [
            ([0xff, 0xff], [0b0000_0010], [0b1111_1111, 0b1111_1110], true),
            ([0b0000_0001, 0b0000_0001], [0b1111_1111], [0b1111_1111, 0b1111_1111], false),
        ];
        for (mut lhs, rhs, expected, flag) in cases {
            assert_eq!(lhs.bit_be_mul(&rhs), flag);
            assert_eq!(lhs, expected);
        }
    }

    /// Reads a big-endian byte slice of at most eight bytes as a `u64`.
    pub trait BeValue {
        fn value(&self) -> u64;
    }

    impl BeValue for [u8] {
        fn value(&self) -> u64 {
            let mut buf = [0u8; 8];
            // Right-align the input so the missing high bytes stay zero.
            let start = buf.len() - self.len();
            buf[start..].copy_from_slice(self);
            u64::from_be_bytes(buf)
        }
    }

    /// `/=` agrees with native integer division on each sample row.
    #[test]
    fn test_bits_div() {
        // (dividend, divisor, expected quotient at the dividend's width)
        const TDATA: &[(&[u8], &[u8], &[u8])] = &[
            (&[0b0000_1100], &[0b0000_0011], &[0b0000_0100]),
            (&[0b0011_0000, 0], &[0b0000_1100], &[0b0000_0100, 0]),
            (&[0b1100_1100], &[0, 0b0000_0011], &[0b0100_0100]),
        ];
        for &(dividend, divisor, quotient) in TDATA {
            // Sanity-check the table itself against native arithmetic first.
            assert_eq!(dividend.value() / divisor.value(), quotient.value());
            let mut actual = dividend.to_vec();
            assert_eq!(actual.bit_be_div(divisor), false);
            assert_eq!(actual, quotient);
        }
    }
}