1use super::BitIterator;
2use super::Bitwise;
3
/// In-place unsigned arithmetic on values stored in big-endian order.
///
/// The receiver has a fixed width: every operation writes its result back
/// into `self` and reports through the returned `bool` whether the operation
/// could not be represented or performed.
pub trait BitArith {
    /// Type of the right-hand operand taken by the in-place operations.
    type Other: ?Sized;

    /// Compares `self` with `other` by numeric value.
    fn bit_be_cmp(&self, other: &Self) -> std::cmp::Ordering;

    /// Adds `other` to `self` in place.
    ///
    /// Returns `true` when the sum overflowed the width of `self` (a carry
    /// left the most significant position); `self` then holds the wrapped
    /// result.
    fn bit_be_add(&mut self, other: &Self::Other) -> bool;

    /// Subtracts `other` from `self` in place.
    ///
    /// Returns `true` when a borrow left the most significant position,
    /// i.e. the result wrapped around.
    fn bit_be_sub(&mut self, other: &Self::Other) -> bool;

    /// Multiplies `self` by `other` in place.
    ///
    /// Returns `true` when the product overflowed the width of `self`.
    fn bit_be_mul(&mut self, other: &Self::Other) -> bool;

    /// Divides `self` by `other` in place, leaving the quotient in `self`.
    ///
    /// Returns `true` when the division could not be performed; the `[u8]`
    /// implementation does so for a zero divisor and leaves `self` untouched.
    fn bit_be_div(&mut self, other: &Self::Other) -> bool;

    /// Reduces `self` modulo `other` in place, leaving the remainder in
    /// `self`.
    ///
    /// Returns `true` when the operation could not be performed; the `[u8]`
    /// implementation does so for a zero divisor and leaves `self` untouched.
    fn bit_be_rem(&mut self, other: &Self::Other) -> bool;
}
74
75impl BitArith for [u8] {
76 type Other = Self;
77
78 fn bit_be_cmp(&self, other: &Self) -> std::cmp::Ordering {
79 let max_len = std::cmp::max(self.len(), other.len());
80 self.extend_be_iter(max_len)
81 .cmp(other.extend_be_iter(max_len))
82 }
83
84 fn bit_be_add(&mut self, other: &Self) -> bool {
85 self.iter_mut()
86 .rev()
87 .zip(other.iter().rev().chain(std::iter::repeat(&0)))
88 .fold(false, |mut carry, (a, b)| {
89 match (carry, *b) {
90 (true, 0xff) => carry = true,
91 (true, _) => (*a, carry) = a.overflowing_add(b + 1),
92 (false, _) => (*a, carry) = a.overflowing_add(*b),
93 };
94 carry
95 })
96 }
97
98 fn bit_be_sub(&mut self, other: &Self) -> bool {
99 self.iter_mut()
100 .rev()
101 .zip(other.iter().rev().chain(std::iter::repeat(&0)))
102 .fold(false, |mut borrow, (a, b)| {
103 match (borrow, *b) {
104 (true, 0xff) => borrow = true,
105 (true, _) => (*a, borrow) = a.overflowing_sub(b + 1),
106 (false, _) => (*a, borrow) = a.overflowing_sub(*b),
107 };
108 borrow
109 })
110 }
111
112 fn bit_be_mul(&mut self, other: &Self) -> bool {
113 let mut result = vec![0; self.len()];
114 let mut overflow = false;
115 for (i, bit) in other.bit_iter().rev().enumerate() {
116 if bit {
117 let mut multiple = self.to_vec();
118 overflow |= multiple.bit_shl(i);
119 overflow |= result.bit_be_add(&multiple);
120 }
121 }
122 self.copy_from_slice(&result);
123 overflow
124 }
125
126 fn bit_be_div(&mut self, other: &Self) -> bool {
127 if other.iter().all(|&b| b == 0) {
128 return true; }
130
131 let bits_a = self.len() * 8 - self.bit_leading_zeros(); let bits_b = other.len() * 8 - other.bit_leading_zeros(); if bits_a < bits_b {
135 self.fill(0);
136 return false;
137 }
138
139 let mut other = other.extend_be(self.len()); {
141 let common_divisor_bits = self.bit_trailing_zeros().min(other.bit_trailing_zeros());
143 self.bit_shr(common_divisor_bits);
144 other.bit_shr(common_divisor_bits);
145 }
146
147 let n = self.len();
149 let mut result = vec![0; n];
150 let diff = bits_a - bits_b;
151 other.bit_shl(diff);
152 for i in (0..=diff).rev() {
153 if self.bit_be_cmp(&other) != std::cmp::Ordering::Less {
154 self.bit_be_sub(&other);
155 result[n - 1 - i / 8] |= 1 << (i % 8);
156 }
157 other.bit_shr(1);
158 }
159
160 self.copy_from_slice(&result);
161 false
162 }
163
164 fn bit_be_rem(&mut self, other: &Self) -> bool {
165 if other.iter().all(|&b| b == 0) {
166 return true; }
168
169 let bits_a = self.len() * 8 - self.bit_leading_zeros(); let bits_b = other.len() * 8 - other.bit_leading_zeros(); if bits_a < bits_b {
173 return false;
174 }
175
176 let mut other = other.extend_be(self.len()); {
178 let common_divisor_bits = self.bit_trailing_zeros().min(other.bit_trailing_zeros());
180 self.bit_shr(common_divisor_bits);
181 other.bit_shr(common_divisor_bits);
182 }
183
184 let n = self.len();
186 let mut result = vec![0; n];
187 let diff = bits_a - bits_b;
188 other.bit_shl(diff);
189 for i in (0..=diff).rev() {
190 if self.bit_be_cmp(&other) != std::cmp::Ordering::Less {
191 self.bit_be_sub(&other);
192 result[n - 1 - i / 8] |= 1 << (i % 8);
193 }
194 other.bit_shr(1);
195 }
196 false
197 }
198}
199
/// Helpers for viewing a big-endian byte sequence at a different width.
trait ByteExtend {
    /// Iterates over the bytes as if the value were left-padded with zero
    /// bytes to a total of `n` bytes.
    ///
    /// `n` is expected to be at least the current length.
    fn extend_be_iter(&self, n: usize) -> impl DoubleEndedIterator<Item = &u8>;

    /// Returns an owned copy of the value that is exactly `n` bytes wide,
    /// padding with leading zero bytes or dropping leading bytes as needed.
    fn extend_be(&self, n: usize) -> Vec<u8>;
}
205
206impl ByteExtend for [u8] {
207 #[inline(always)]
208 fn extend_be_iter(&self, n: usize) -> impl DoubleEndedIterator<Item = &u8> {
209 std::iter::repeat_n(&0, n - self.len()).chain(self.iter())
210 }
211
212 #[inline(always)]
213 fn extend_be(&self, n: usize) -> Vec<u8> {
214 let mut data = vec![0; n];
215 if self.len() < n {
216 data[n - self.len()..].copy_from_slice(self);
217 } else {
218 assert!(self[..self.len() - n].iter().all(|&b| b == 0));
219 data.copy_from_slice(&self[self.len() - n..]);
220 }
221 data
222 }
223}
224
#[cfg(test)]
mod test_arith {
    use super::*;

    /// Test oracle: reads a big-endian slice of at most 8 bytes as a `u64`
    /// so results can be checked against native integer arithmetic.
    pub trait BeValue {
        fn value(&self) -> u64;
    }

    impl BeValue for [u8] {
        fn value(&self) -> u64 {
            let mut bytes = [0; 8];
            bytes[8 - self.len()..].copy_from_slice(self);
            u64::from_be_bytes(bytes)
        }
    }

    #[test]
    fn test_bits_cmp() {
        let short: &[u8] = &[0b0000_0001];
        let padded: &[u8] = &[0b0000_0000, 0b0000_0001];
        let large: &[u8] = &[0b0000_0001, 0b0000_0000];
        // Leading zero bytes do not affect the ordering.
        assert_eq!(padded.bit_be_cmp(short), std::cmp::Ordering::Equal);
        assert_eq!(large.bit_be_cmp(short), std::cmp::Ordering::Greater);
        assert_eq!(short.bit_be_cmp(large), std::cmp::Ordering::Less);
    }

    #[test]
    fn test_bits_add() {
        let mut a: [u8; 2] = [0b1111_1111, 0b1111_1111];
        assert!(a.bit_be_add(&[0b0000_0001]));
        assert_eq!(a, [0b0000_0000, 0b0000_0000]);

        let mut a: [u8; 2] = [0b0000_0000, 0b0000_0001];
        assert!(!a.bit_be_add(&[0b1111_1111]));
        assert_eq!(a, [0b0000_0001, 0b0000_0000]);
    }

    #[test]
    fn test_bits_sub() {
        let mut a: [u8; 2] = [0b0000_0000, 0b0000_0001];
        assert!(a.bit_be_sub(&[0b1111_1111]));
        assert_eq!(a, [0b1111_1111, 0b0000_0010]);

        let mut a: [u8; 2] = [0b1111_1111, 0b0000_0000];
        assert!(!a.bit_be_sub(&[0b0000_0001]));
        assert_eq!(a, [0b1111_1110, 0b1111_1111]);
    }

    #[test]
    fn test_bits_mul() {
        let mut a: [u8; 2] = [0xff, 0xff];
        assert!(a.bit_be_mul(&[0b0000_0010]));
        assert_eq!(a, [0b1111_1111, 0b1111_1110]);

        let mut a: [u8; 2] = [0b0000_0001, 0b0000_0001];
        assert!(!a.bit_be_mul(&[0b1111_1111]));
        assert_eq!(a, [0b1111_1111, 0b1111_1111]);
    }

    #[test]
    fn test_bits_div() {
        // (dividend, divisor, expected quotient)
        const TDATA: &[(&[u8], &[u8], &[u8])] = &[
            (&[0b0000_1100], &[0b0000_0011], &[0b0000_0100]),
            (&[0b0011_0000, 0], &[0b0000_1100], &[0b0000_0100, 0]),
            (&[0b1100_1100], &[0, 0b0000_0011], &[0b0100_0100]),
        ];
        for (a, b, c) in TDATA {
            // Cross-check the table itself against native arithmetic.
            assert_eq!(a.value() / b.value(), c.value());
            let mut a = a.to_vec();
            assert!(!a.bit_be_div(b));
            assert_eq!(&a, c);
        }
    }

    #[test]
    fn test_bits_rem() {
        // (dividend, divisor, expected remainder)
        const TDATA: &[(&[u8], &[u8], &[u8])] = &[
            (&[0b0000_1101], &[0b0000_0100], &[0b0000_0001]),
            (&[0b0110_0100], &[0b0000_0111], &[0b0000_0010]),
            (&[0b0000_0001, 0b0000_0001], &[0b0001_0000], &[0, 0b0000_0001]),
            (&[0b0000_0011], &[0, 0b0000_1001], &[0b0000_0011]),
        ];
        for (a, b, r) in TDATA {
            // Cross-check the table itself against native arithmetic.
            assert_eq!(a.value() % b.value(), r.value());
            let mut a = a.to_vec();
            assert!(!a.bit_be_rem(b));
            assert_eq!(&a, r);
        }
    }

    #[test]
    fn test_bits_div_rem_by_zero() {
        // A zero divisor is reported through the flag and leaves the
        // dividend unchanged.
        let mut a: [u8; 1] = [0b0000_1100];
        assert!(a.bit_be_div(&[0b0000_0000]));
        assert_eq!(a, [0b0000_1100]);
        assert!(a.bit_be_rem(&[0b0000_0000, 0b0000_0000]));
        assert_eq!(a, [0b0000_1100]);
    }
}