Skip to main content

baa/bv/
arithmetic.rs

1// Copyright 2023-2024 The Regents of the University of California
2// Copyright 2024 Cornell University
3// released under BSD 3-Clause License
4// author: Kevin Laeufer <laeufer@cornell.edu>
5//
6// basic arithmetic implementations
7
8use crate::{WidthInt, Word};
9use std::cmp::Ordering;
10
/// Wide integer used for intermediate sums and products of `Word` values
/// (see `adc` and `mul_word`), so that the carry / upper half is not lost.
/// NOTE(review): presumably exactly twice as wide as `Word` — confirm against the `Word` definition.
// TODO: make sure this is updated together with the Word type
type DoubleWord = u128;
13
14#[inline]
15pub fn mask(bits: WidthInt) -> Word {
16    if bits == Word::BITS || bits == 0 {
17        Word::MAX
18    } else {
19        assert!(bits < Word::BITS);
20        ((1 as Word) << bits) - 1
21    }
22}
23
24#[inline]
25pub fn mask_double_word(bits: WidthInt) -> DoubleWord {
26    if bits == DoubleWord::BITS || bits == 0 {
27        DoubleWord::MAX
28    } else {
29        assert!(bits < DoubleWord::BITS);
30        ((1 as DoubleWord) << bits) - 1
31    }
32}
33
34#[inline]
35pub(crate) fn clear(dst: &mut [Word]) {
36    for w in dst.iter_mut() {
37        *w = 0;
38    }
39}
40
41#[inline]
42fn set(dst: &mut [Word]) {
43    for w in dst.iter_mut() {
44        *w = Word::MAX;
45    }
46}
47
48#[inline]
49pub(crate) fn assign(dst: &mut [Word], source: &[Word]) {
50    for (d, s) in dst.iter_mut().zip(source.iter()) {
51        *d = *s;
52    }
53}
54
55#[inline]
56pub(crate) fn zero_extend(dst: &mut [Word], source: &[Word]) {
57    // copy source to dst
58    assign(dst, source);
59    // zero out remaining words
60    clear(&mut dst[source.len()..]);
61}
62
63#[inline]
64pub(crate) fn sign_extend(
65    dst: &mut [Word],
66    source: &[Word],
67    src_width: WidthInt,
68    dst_width: WidthInt,
69) {
70    // copy source to dst
71    assign(dst, source);
72    if is_neg(source, src_width) {
73        // set source msbs in destination
74        let lsbs_in_msb = src_width % Word::BITS;
75        if lsbs_in_msb > 0 {
76            let msbs_in_msb = Word::BITS - lsbs_in_msb;
77            dst[source.len() - 1] |= mask(msbs_in_msb) << lsbs_in_msb;
78        }
79        // set other dst bytes to all 1s
80        set(&mut dst[source.len()..]);
81        // clear destination msbs
82        mask_msb(dst, dst_width);
83    } else {
84        clear(&mut dst[source.len()..]);
85    }
86}
87
88#[inline]
89pub(crate) fn mask_msb(dst: &mut [Word], width: WidthInt) {
90    debug_assert_eq!(width.div_ceil(Word::BITS) as usize, dst.len());
91    let m = mask(width % Word::BITS);
92    *dst.last_mut().unwrap() &= m;
93}
94
95#[inline]
96pub(crate) fn is_bit_set(source: &[Word], pos: WidthInt) -> bool {
97    let bit_idx = pos % Word::BITS;
98    let word_idx = (pos / Word::BITS) as usize;
99    (source[word_idx] >> bit_idx) & 1 == 1
100}
101
102#[inline]
103pub(crate) fn set_bit(dst: &mut [Word], pos: WidthInt) {
104    let bit_idx = pos % Word::BITS;
105    let word_idx = (pos / Word::BITS) as usize;
106    dst[word_idx] |= 1 << bit_idx;
107}
108
109#[inline]
110pub(crate) fn clear_bit(dst: &mut [Word], pos: WidthInt) {
111    let bit_idx = pos % Word::BITS;
112    let word_idx = (pos / Word::BITS) as usize;
113    dst[word_idx] &= !(1 << bit_idx);
114}
115
116#[inline]
117pub(crate) fn slice(dst: &mut [Word], source: &[Word], hi: WidthInt, lo: WidthInt) {
118    let lo_offset = lo % Word::BITS;
119    let hi_word = (hi / Word::BITS) as usize;
120    let lo_word = (lo / Word::BITS) as usize;
121    let src = &source[lo_word..(hi_word + 1)];
122
123    let shift_right = lo_offset;
124    if shift_right == 0 {
125        assign(dst, src);
126    } else {
127        // assign with a shift
128        let shift_left = Word::BITS - shift_right;
129        let m = mask(shift_right);
130        let mut prev = src[0] >> shift_right;
131        // We append a zero to the src iter in case src.len() == dst.len().
132        // If src.len() == dst.len() + 1, then the 0 will just be ignored by `zip`.
133        for (d, s) in dst.iter_mut().zip(src.iter().skip(1).chain([0].iter())) {
134            *d = prev | ((*s) & m) << shift_left;
135            prev = (*s) >> shift_right;
136        }
137    }
138    // mask the result msb
139    mask_msb(dst, hi - lo + 1);
140}
141
142#[inline]
143pub(crate) fn concat(dst: &mut [Word], msb: &[Word], lsb: &[Word], lsb_width: WidthInt) {
144    // copy lsb to dst
145    assign(dst, lsb);
146
147    let lsb_offset = lsb_width % Word::BITS;
148    if lsb_offset == 0 {
149        // copy msb to dst
150        for (d, m) in dst.iter_mut().skip(lsb.len()).zip(msb.iter()) {
151            *d = *m;
152        }
153    } else {
154        // copy a shifted version of the msb to dst
155        let shift_right = Word::BITS - lsb_offset;
156        let m = mask(shift_right);
157        let mut prev = dst[lsb.len() - 1]; // the msb of the lsb
158        for (d, s) in dst
159            .iter_mut()
160            .skip(lsb.len() - 1)
161            .zip(msb.iter().chain([0].iter()))
162        {
163            *d = prev | ((*s) & m) << lsb_offset;
164            prev = (*s) >> shift_right;
165        }
166    }
167}
168
169#[inline]
170pub(crate) fn not(dst: &mut [Word], source: &[Word], width: WidthInt) {
171    bitwise_un_op(dst, source, |e| !e);
172    mask_msb(dst, width);
173}
174
175#[inline]
176fn bitwise_un_op(dst: &mut [Word], source: &[Word], op: fn(Word) -> Word) {
177    for (d, s) in dst.iter_mut().zip(source.iter()) {
178        *d = (op)(*s);
179    }
180}
181
182#[inline]
183pub(crate) fn and(dst: &mut [Word], a: &[Word], b: &[Word]) {
184    bitwise_bin_op(dst, a, b, |a, b| a & b)
185}
186
187#[inline]
188pub(crate) fn or(dst: &mut [Word], a: &[Word], b: &[Word]) {
189    bitwise_bin_op(dst, a, b, |a, b| a | b)
190}
191
192#[inline]
193pub(crate) fn xor(dst: &mut [Word], a: &[Word], b: &[Word]) {
194    bitwise_bin_op(dst, a, b, |a, b| a ^ b)
195}
196
197#[inline]
198fn bitwise_bin_op(dst: &mut [Word], a: &[Word], b: &[Word], op: fn(Word, Word) -> Word) {
199    for (d, (a, b)) in dst.iter_mut().zip(a.iter().zip(b.iter())) {
200        *d = (op)(*a, *b);
201    }
202}
203
204#[inline]
205fn adc(dst: &mut Word, carry: u8, a: Word, b: Word) -> u8 {
206    let sum = carry as DoubleWord + a as DoubleWord + b as DoubleWord;
207    let new_carry = (sum >> Word::BITS) as u8;
208    *dst = sum as Word;
209    new_carry
210}
211
212/// Add function inspired by the num-bigint implementation: https://docs.rs/num-bigint/0.4.4/src/num_bigint/biguint/addition.rs.html
213#[inline]
214pub(crate) fn add(dst: &mut [Word], a: &[Word], b: &[Word], width: WidthInt) {
215    let mut carry = 0;
216    for (dd, (aa, bb)) in dst.iter_mut().zip(a.iter().zip(b.iter())) {
217        carry = adc(dd, carry, *aa, *bb);
218    }
219    mask_msb(dst, width);
220}
221
222/// Sub function inspired by the num-bigint implementation: https://docs.rs/num-bigint/0.4.4/src/num_bigint/biguint/subtraction.rs.html
223#[inline]
224pub(crate) fn sub(dst: &mut [Word], a: &[Word], b: &[Word], width: WidthInt) {
225    // we add one by setting the input carry to one
226    let mut carry = 1;
227    for (dd, (aa, bb)) in dst.iter_mut().zip(a.iter().zip(b.iter())) {
228        // we invert b which in addition to adding 1 turns it into `-b`
229        carry = adc(dd, carry, *aa, !(*bb));
230    }
231    mask_msb(dst, width);
232}
233
234/// Mul function inspired by the num-bigint implementation: https://docs.rs/num-bigint/0.4.4/src/num_bigint/biguint/multiplication.rs.html
235#[inline]
236pub(crate) fn mul(dst: &mut [Word], a: &[Word], b: &[Word], width: WidthInt) {
237    if width <= Word::BITS {
238        let (res, _) = a[0].overflowing_mul(b[0]);
239        dst[0] = res & mask(width);
240    } else {
241        todo!(
242            "implement multiplication for bit vectors larger {}",
243            Word::BITS
244        );
245    }
246}
247
248/// Multiplies `dst` with the word `value`. Does not mask the MSB since this is an internal subroutine.
249#[inline]
250pub(crate) fn mul_word(dst: &mut [Word], value: Word) {
251    let mut carry = 0;
252    for w in dst.iter_mut() {
253        let res = *w as DoubleWord * value as DoubleWord + carry as DoubleWord;
254        carry = (res >> Word::BITS) as Word;
255        *w = (res & Word::MAX as DoubleWord) as Word;
256    }
257}
258
259/// Adds `value` to  `dst`. Does not mask the MSB since this is an internal subroutine.
260#[inline]
261pub(crate) fn add_word(dst: &mut [Word], value: Word) {
262    let mut carry = 0;
263    for (ii, w) in dst.iter_mut().enumerate() {
264        let aa = *w;
265        let bb = if ii == 0 { value } else { 0 };
266        carry = adc(w, carry, aa, bb);
267    }
268}
269
270#[inline]
271pub(crate) fn shift_right(
272    dst: &mut [Word],
273    a: &[Word],
274    b: &[Word],
275    width: WidthInt,
276) -> Option<WidthInt> {
277    // clear the destination
278    clear(dst);
279
280    // check to see if we are shifting for more than our width
281    let shift_amount = get_shift_amount(b, width)?;
282
283    // otherwise we actually perform the shift by converting it to a slice
284    let hi = width - 1;
285    let lo = shift_amount;
286    let result_width = hi - lo + 1;
287    let result_words = result_width.div_ceil(Word::BITS) as usize;
288    slice(&mut dst[..result_words], a, hi, lo);
289    Some(shift_amount)
290}
291
292#[inline]
293pub(crate) fn arithmetic_shift_right(dst: &mut [Word], a: &[Word], b: &[Word], width: WidthInt) {
294    // perform shift
295    let shift_amount = shift_right(dst, a, b, width);
296
297    // pad with sign bit if necessary
298    if is_neg(a, width) {
299        match shift_amount {
300            None => {
301                // over shift => we just need to set everything to 1
302                for d in dst.iter_mut() {
303                    *d = Word::MAX;
304                }
305                mask_msb(dst, width);
306            }
307            Some(amount) => {
308                if amount > 0 {
309                    let res_width = width - amount;
310                    let local_msb = (res_width - 1) % Word::BITS;
311                    let msb_word = ((res_width - 1) / Word::BITS) as usize;
312                    if local_msb < (Word::BITS - 1) {
313                        let msb_word_mask = mask(Word::BITS - (local_msb + 1));
314                        dst[msb_word] |= msb_word_mask << (local_msb + 1);
315                    }
316                    for d in dst[(msb_word + 1)..].iter_mut() {
317                        *d = Word::MAX;
318                    }
319                    mask_msb(dst, width);
320                }
321            }
322        }
323    }
324}
325
326#[inline]
327pub(crate) fn shift_left(dst: &mut [Word], a: &[Word], b: &[Word], width: WidthInt) {
328    // check to see if we are shifting for more than our width
329    let shift_amount = match get_shift_amount(b, width) {
330        None => {
331            clear(dst);
332            return;
333        }
334        Some(value) => value,
335    };
336
337    // otherwise we actually perform the shift
338    let shift_left = shift_amount % Word::BITS;
339    let shift_words = shift_amount / Word::BITS;
340    let shift_right = Word::BITS - shift_left;
341    let zeros = std::iter::repeat_n(&(0 as Word), shift_words as usize);
342    let mut prev = 0;
343    for (d, s) in dst.iter_mut().zip(zeros.chain(a.iter())) {
344        if shift_left == 0 {
345            *d = *s;
346        } else {
347            *d = (*s << shift_left) | prev;
348            prev = *s >> shift_right;
349        }
350    }
351    if shift_left > 0 {
352        mask_msb(dst, width);
353    }
354}
355
356#[inline]
357fn get_shift_amount(b: &[Word], width: WidthInt) -> Option<WidthInt> {
358    let msb_set = b.iter().skip(1).any(|w| *w != 0);
359    let shift_amount = b[0];
360    if msb_set || shift_amount >= width as Word {
361        None // result is just zero or the sign bit
362    } else {
363        Some(shift_amount as WidthInt)
364    }
365}
366
367#[inline]
368pub(crate) fn negate(dst: &mut [Word], b: &[Word], width: WidthInt) {
369    dst.clone_from_slice(b);
370    negate_in_place(dst, width);
371}
372
373#[inline]
374pub(crate) fn negate_in_place(dst: &mut [Word], width: WidthInt) {
375    // we add one by setting the input carry to one
376    let mut carry = 1;
377    for dd in dst.iter_mut() {
378        // we invert b which in addition to adding 1 turns it into `-b`
379        let b = !(*dd);
380        carry = adc(dd, carry, 0, b);
381    }
382    mask_msb(dst, width);
383}
384
385#[inline]
386pub(crate) fn cmp_equal(a: &[Word], b: &[Word]) -> bool {
387    a.iter().zip(b.iter()).all(|(a, b)| a == b)
388}
389
390#[inline]
391pub(crate) fn cmp_greater(a: &[Word], b: &[Word]) -> bool {
392    is_greater_and_not_less(a, b).unwrap_or(false)
393}
394
395#[inline]
396pub(crate) fn is_neg(src: &[Word], width: WidthInt) -> bool {
397    let msb_bit_id = (width - 1) % Word::BITS;
398    let msb_word = src.last().unwrap();
399    let msb_bit_value = ((msb_word) >> msb_bit_id) & 1;
400    msb_bit_value == 1
401}
402
403#[inline]
404pub(crate) fn is_pow2(words: &[Word]) -> Option<WidthInt> {
405    // find most significant bit set
406    let mut bit_pos = None;
407    for (word_ii, &word) in words.iter().enumerate() {
408        if bit_pos.is_none() {
409            if word != 0 {
410                // is there only one bit set?
411                if word.leading_zeros() + word.trailing_zeros() == Word::BITS - 1 {
412                    bit_pos = Some(word.trailing_zeros() + word_ii as WidthInt * Word::BITS);
413                } else {
414                    // more than one bit set
415                    return None;
416                }
417            }
418        } else if word != 0 {
419            // more than one bit set
420            return None;
421        }
422    }
423    bit_pos
424}
425
426#[inline]
427pub(crate) fn min_width(words: &[Word]) -> WidthInt {
428    // find most significant bit set
429    for (word_ii, &word) in words.iter().enumerate() {
430        if word != 0 {
431            // cannot underflow since word.leading_zeros() is always less than Word::BITS
432            let bit_pos = Word::BITS - word.leading_zeros() - 1;
433            return word_ii as WidthInt * Word::BITS + bit_pos + 1;
434        }
435    }
436    // all words are zero
437    0
438}
439
440#[inline]
441pub(crate) fn cmp_greater_signed(a: &[Word], b: &[Word], width: WidthInt) -> bool {
442    let (is_neg_a, is_neg_b) = (is_neg(a, width), is_neg(b, width));
443    match (is_neg_a, is_neg_b) {
444        (true, false) => false, // -|a| < |b|
445        (false, true) => true,  // |a| > -|b|
446        (false, false) => cmp_greater(a, b),
447        (true, true) => cmp_greater(a, b), // TODO: does this actually work?
448    }
449}
450
451/// `Some(true)` if `a > b`, `Some(false)` if `a < b`, None if `a == b`
452#[inline]
453fn is_greater_and_not_less(a: &[Word], b: &[Word]) -> Option<bool> {
454    for (a, b) in a.iter().rev().zip(b.iter().rev()) {
455        match a.cmp(b) {
456            Ordering::Less => return Some(false),
457            Ordering::Equal => {} // continue
458            Ordering::Greater => return Some(true),
459        }
460    }
461    None
462}
463
464#[inline]
465pub(crate) fn cmp_greater_equal(a: &[Word], b: &[Word]) -> bool {
466    is_greater_and_not_less(a, b).unwrap_or(true)
467}
468
469#[inline]
470pub(crate) fn cmp_greater_equal_signed(a: &[Word], b: &[Word], width: WidthInt) -> bool {
471    match (is_neg(a, width), is_neg(b, width)) {
472        (true, false) => false, // -|a| < |b|
473        (false, true) => true,  // |a| > -|b|
474        (false, false) => cmp_greater_equal(a, b),
475        (true, true) => cmp_greater_equal(a, b), // TODO: does this actually work?
476    }
477}
478
479#[inline]
480pub(crate) fn word_to_bool(value: Word) -> bool {
481    (value & 1) == 1
482}
483
/// Test helper: asserts that all bits above `width` in the last word of `value` are zero.
///
/// Nothing is checked when `width` is a multiple of `Word::BITS`, since the
/// last word is then fully used.
#[cfg(test)]
pub(crate) fn assert_unused_bits_zero(value: &[Word], width: WidthInt) {
    let used_bits = width % Word::BITS;
    if used_bits == 0 {
        return;
    }
    let top_word = *value.last().unwrap();
    let unused = top_word & !mask(used_bits);
    assert_eq!(unused, 0, "unused msb bits need to be zero!")
}
494
495pub(crate) fn find_ranges_of_ones(words: &[Word]) -> Vec<std::ops::Range<WidthInt>> {
496    // the actual width does not matter since we assume that all unused bits in the msb are set to zero
497    let mut out = vec![];
498    let mut range_start: Option<WidthInt> = None;
499    for (word_ii, word) in words.iter().enumerate() {
500        let lsb_ii = word_ii as WidthInt * Word::BITS;
501        let mut word = *word;
502        let mut bits_consumed = 0;
503
504        // handle open range from previous word
505        if let Some(start) = range_start {
506            let ones = word.trailing_ones();
507            bits_consumed += ones;
508            word >>= ones;
509            if ones < Word::BITS {
510                range_start = None;
511                out.push(start..lsb_ii + bits_consumed);
512            }
513        }
514
515        // find ranges in this word
516        while bits_consumed < Word::BITS {
517            debug_assert!(range_start.is_none());
518            if word == 0 {
519                // done
520                bits_consumed = Word::BITS;
521            } else {
522                let zeros = word.trailing_zeros();
523                bits_consumed += zeros;
524                word >>= zeros;
525                let start = bits_consumed;
526                let ones = word.trailing_ones();
527                bits_consumed += ones;
528                word = word.overflowing_shr(ones).0;
529                match bits_consumed.cmp(&Word::BITS) {
530                    Ordering::Less => {
531                        let end = bits_consumed;
532                        out.push(lsb_ii + start..lsb_ii + end);
533                    }
534                    Ordering::Equal => {
535                        // done, range might expand to next word
536                        range_start = Some(start + lsb_ii);
537                    }
538                    Ordering::Greater => {
539                        unreachable!("")
540                    }
541                }
542            }
543        }
544    }
545    // finish open range
546    if let Some(start) = range_start {
547        let end = words.len() as WidthInt * Word::BITS;
548        out.push(start..end);
549    }
550
551    out
552}