core_models/abstractions/
bitvec.rs

1//! This module provides a specification-friendly bit vector type.
2use super::bit::{Bit, MachineInteger};
3use super::funarr::*;
4
5use std::fmt::Formatter;
6
7// This is required due to some hax-lib inconsistencies with versus without `cfg(hax)`.
8#[cfg(hax)]
9use hax_lib::{int, ToInt};
10
// TODO: this module uses `u128`/`i128` as mathematical integers. We should use `hax_lib::int` or a bigint instead.
12
13/// A fixed-size bit vector type.
14///
15/// `BitVec<N>` is a specification-friendly, fixed-length bit vector that internally
16/// stores an array of [`Bit`] values, where each `Bit` represents a single binary digit (0 or 1).
17///
18/// This type provides several utility methods for constructing and converting bit vectors:
19///
20/// The [`Debug`] implementation for `BitVec` pretty-prints the bits in groups of eight,
21/// making the bit pattern more human-readable. The type also implements indexing,
22/// allowing for easy access to individual bits.
23#[hax_lib::fstar::before("noeq")]
24#[derive(Copy, Clone, Eq, PartialEq)]
25pub struct BitVec<const N: u64>(FunArray<N, Bit>);
26
27/// Pretty prints a bit slice by group of 8
28#[hax_lib::exclude]
29fn bit_slice_to_string(bits: &[Bit]) -> String {
30    bits.iter()
31        .map(|bit| match bit {
32            Bit::Zero => '0',
33            Bit::One => '1',
34        })
35        .collect::<Vec<_>>()
36        .chunks(8)
37        .map(|bits| bits.iter().collect::<String>())
38        .map(|s| format!("{s} "))
39        .collect::<String>()
40        .trim()
41        .into()
42}
43
44#[hax_lib::exclude]
45impl<const N: u64> core::fmt::Debug for BitVec<N> {
46    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
47        write!(f, "{}", bit_slice_to_string(&self.0.as_vec()))
48    }
49}
50
51#[hax_lib::attributes]
52impl<const N: u64> core::ops::Index<u64> for BitVec<N> {
53    type Output = Bit;
54    #[requires(index < N)]
55    fn index(&self, index: u64) -> &Self::Output {
56        self.0.get(index)
57    }
58}
59
60/// Convert a bit slice into an unsigned number.
61#[hax_lib::exclude]
62fn u128_int_from_bit_slice(bits: &[Bit]) -> u128 {
63    bits.iter()
64        .enumerate()
65        .map(|(i, bit)| u128::from(*bit) << i)
66        .sum::<u128>()
67}
68
69/// Convert a bit slice into a machine integer of type `T`.
70#[hax_lib::exclude]
71fn int_from_bit_slice<T: TryFrom<i128> + MachineInteger + Copy>(bits: &[Bit]) -> T {
72    debug_assert!(bits.len() <= T::bits() as usize);
73    let result = if T::SIGNED {
74        let is_negative = matches!(bits[T::bits() as usize - 1], Bit::One);
75        let s = u128_int_from_bit_slice(&bits[0..T::bits() as usize - 1]) as i128;
76        if is_negative {
77            s + (-2i128).pow(T::bits() - 1)
78        } else {
79            s
80        }
81    } else {
82        u128_int_from_bit_slice(bits) as i128
83    };
84    let Ok(n) = result.try_into() else {
85        // Conversion must succeed as `result` is guaranteed to be in range due to the bit-length check.
86        unreachable!()
87    };
88    n
89}
90
91#[hax_lib::fstar::replace(
92    r#"
93let ${BitVec::<0>::from_fn::<fn(u64)->Bit>}
94    (v_N: u64)
95    (f: (i: u64 {v i < v v_N}) -> $:{Bit})
96    : t_BitVec v_N = 
97    ${BitVec::<0>}(${FunArray::<0,()>::from_fn::<fn(u64)->()>} v_N f)
98"#
99)]
100const _: () = ();
101
102macro_rules! impl_pointwise {
103    ($n:literal, $($i:literal)*) => {
104        impl BitVec<$n> {
105            pub fn pointwise(self) -> Self {
106                Self::from_fn(|i| match i {
107                    $($i => self[$i],)*
108                    _ => unreachable!(),
109                })
110            }
111        }
112    };
113}
114
115impl_pointwise!(128, 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127);
116impl_pointwise!(256, 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255);
117
/// An F* attribute that indicates a rewriting lemma should be applied.
///
/// The `bitvec_postprocess_norm` tactic looks up every lemma carrying this attribute
/// and applies them as left-to-right rewrites.
pub const REWRITE_RULE: () = {};
120
121#[hax_lib::exclude]
122impl<const N: u64> BitVec<N> {
123    /// Constructor for BitVec. `BitVec::<N>::from_fn` constructs a bitvector out of a function that takes usizes smaller than `N` and produces bits.
124    pub fn from_fn<F: Fn(u64) -> Bit>(f: F) -> Self {
125        Self(FunArray::from_fn(f))
126    }
127    /// Convert a slice of machine integers where only the `d` least significant bits are relevant.
128    pub fn from_slice<T: Into<i128> + MachineInteger + Copy>(x: &[T], d: u64) -> Self {
129        Self::from_fn(|i| Bit::of_int::<T>(x[(i / d) as usize], (i % d) as u32))
130    }
131
132    /// Construct a BitVec out of a machine integer.
133    pub fn from_int<T: Into<i128> + MachineInteger + Copy>(n: T) -> Self {
134        Self::from_slice::<T>(&[n], T::bits() as u64)
135    }
136
137    /// Convert a BitVec into a machine integer of type `T`.
138    pub fn to_int<T: TryFrom<i128> + MachineInteger + Copy>(self) -> T {
139        int_from_bit_slice(&self.0.as_vec())
140    }
141
142    /// Convert a BitVec into a vector of machine integers of type `T`.
143    pub fn to_vec<T: TryFrom<i128> + MachineInteger + Copy>(&self) -> Vec<T> {
144        self.0
145            .as_vec()
146            .chunks(T::bits() as usize)
147            .map(int_from_bit_slice)
148            .collect()
149    }
150
151    /// Generate a random BitVec.
152    pub fn rand() -> Self {
153        use rand::prelude::*;
154        let random_source: Vec<_> = {
155            let mut rng = rand::rng();
156            (0..N).map(|_| rng.random::<bool>()).collect()
157        };
158        Self::from_fn(|i| random_source[i as usize].into())
159    }
160}
161
162#[hax_lib::fstar::replace(
163    r#"
164open FStar.FunctionalExtensionality
165
166let extensionality' (#a: Type) (#b: Type) (f g: FStar.FunctionalExtensionality.(a ^-> b))
167  : Lemma (ensures (FStar.FunctionalExtensionality.feq f g <==> f == g))
168  = ()
169
170let mark_to_normalize #t (x: t): t = x
171
172open FStar.Tactics.V2
173#push-options "--z3rlimit 80 --admit_smt_queries true"
174let bitvec_rewrite_lemma_128 (x: $:{BitVec<128>})
175: Lemma (x == mark_to_normalize (${BitVec::<128>::pointwise} x)) =
176    let a = x._0._0 in
177    let b = (${BitVec::<128>::pointwise} x)._0._0 in
178    assert_norm (FStar.FunctionalExtensionality.feq a b);
179    extensionality' a b
180
181let bitvec_rewrite_lemma_256 (x: $:{BitVec<256>})
182: Lemma (x == mark_to_normalize (${BitVec::<256>::pointwise} x)) =
183    let a = x._0._0 in
184    let b = (${BitVec::<256>::pointwise} x)._0._0 in
185    assert_norm (FStar.FunctionalExtensionality.feq a b);
186    extensionality' a b
187#pop-options
188
189let bitvec_postprocess_norm_aux (): Tac unit = with_compat_pre_core 1 (fun () ->
190    let debug_mode = ext_enabled "debug_bv_postprocess_rewrite" in
191    let crate = match cur_module () with | crate::_ -> crate | _ -> fail "Empty module name" in
192    // Remove indirections
193    norm [primops; iota; delta_namespace [crate; "Libcrux_intrinsics"]; zeta_full];
194    // Rewrite call chains
195    let lemmas = FStar.List.Tot.map (fun f -> pack_ln (FStar.Stubs.Reflection.V2.Data.Tv_FVar f)) (lookup_attr (`${REWRITE_RULE}) (top_env ())) in
196    l_to_r lemmas;
197    /// Get rid of casts
198    norm [primops; iota; delta_namespace ["Rust_primitives"; "Prims.pow2"]; zeta_full];
199    if debug_mode then print ("[postprocess_rewrite_helper] lemmas = " ^ term_to_string (quote lemmas));
200
201    l_to_r [`bitvec_rewrite_lemma_128; `bitvec_rewrite_lemma_256];
202
203    let round _: Tac unit =
204        if debug_mode then dump "[postprocess_rewrite_helper] Rewrote goal";
205        // Normalize as much as possible
206        norm [primops; iota; delta_namespace ["Core"; crate; "Core_models"; "Libcrux_intrinsics"; "FStar.FunctionalExtensionality"; "Rust_primitives"]; zeta_full];
207        if debug_mode then print ("[postprocess_rewrite_helper] first norm done");
208        // Compute the last bits
209        // compute ();
210        // if debug_mode then dump ("[postprocess_rewrite_helper] compute done");
211        // Force full normalization
212        norm [primops; iota; delta; unascribe; zeta_full];
213        if debug_mode then dump "[postprocess_rewrite_helper] after full normalization";
214        // Solves the goal `<normalized body> == ?u`
215        trefl ()
216    in
217
218    ctrl_rewrite BottomUp (fun t ->
219        let f, args = collect_app t in
220        let matches = match inspect f with | Tv_UInst f _ | Tv_FVar f -> (inspect_fv f) = explode_qn (`%mark_to_normalize) | _ -> false in
221        let has_two_args = match args with | [_; _] -> true | _ -> false in
222        (matches && has_two_args, Continue)
223    ) round;
224
225    // Solves the goal `<normalized body> == ?u`
226    trefl ()
227)
228
229let ${bitvec_postprocess_norm} (): Tac unit =
230    if lax_on ()
231    then trefl () // don't bother rewritting the goal
232    else bitvec_postprocess_norm_aux ()
233"#
234)]
235/// This function is useful only for verification in F*.
236/// Used with `postprocess_rewrite`, this tactic:
237///  1. Applies a series of rewrite rules (the lemmas marked with `REWRITE_RULE`)
238///  2. Normalizes, bottom-up, every sub-expressions typed `BitVec<_>` inside the body of a function.
239/// This tactic should be used on expressions that compute a _static_ permutation of bits.
240pub fn bitvec_postprocess_norm() {}
241
242#[hax_lib::attributes]
243impl<const N: u64> BitVec<N> {
244    #[hax_lib::requires(CHUNK > 0 && CHUNK.to_int() * SHIFTS.to_int() == N.to_int())]
245    pub fn chunked_shift<const CHUNK: u64, const SHIFTS: u64>(
246        self,
247        shl: FunArray<SHIFTS, i128>,
248    ) -> BitVec<N> {
249        // TODO: this inner method is because of https://github.com/cryspen/hax-evit/issues/29
250        #[hax_lib::fstar::options("--z3rlimit 50 --split_queries always")]
251        #[hax_lib::requires(CHUNK > 0 && CHUNK.to_int() * SHIFTS.to_int() == N.to_int())]
252        fn chunked_shift<const N: u64, const CHUNK: u64, const SHIFTS: u64>(
253            bitvec: BitVec<N>,
254            shl: FunArray<SHIFTS, i128>,
255        ) -> BitVec<N> {
256            BitVec::from_fn(|i| {
257                let nth_bit = i % CHUNK;
258                let nth_chunk = i / CHUNK;
259                hax_lib::assert_prop!(nth_chunk.to_int() <= SHIFTS.to_int() - int!(1));
260                hax_lib::assert_prop!(
261                    nth_chunk.to_int() * CHUNK.to_int()
262                        <= (SHIFTS.to_int() - int!(1)) * CHUNK.to_int()
263                );
264                let shift: i128 = if nth_chunk < SHIFTS {
265                    shl[nth_chunk]
266                } else {
267                    0
268                };
269                let local_index = (nth_bit as i128).wrapping_sub(shift);
270                if local_index < CHUNK as i128 && local_index >= 0 {
271                    let local_index = local_index as u64;
272                    hax_lib::assert_prop!(
273                        nth_chunk.to_int() * CHUNK.to_int() + local_index.to_int()
274                            < SHIFTS.to_int() * CHUNK.to_int()
275                    );
276                    bitvec[nth_chunk * CHUNK + local_index]
277                } else {
278                    Bit::Zero
279                }
280            })
281        }
282        chunked_shift::<N, CHUNK, SHIFTS>(self, shl)
283    }
284
285    /// Folds over the array, accumulating a result.
286    ///
287    /// # Arguments
288    /// * `init` - The initial value of the accumulator.
289    /// * `f` - A function combining the accumulator and each element.
290    pub fn fold<A>(&self, init: A, f: fn(A, Bit) -> A) -> A {
291        self.0.fold(init, f)
292    }
293}
294
295pub mod int_vec_interp {
296    //! This module defines interpretation for bit vectors as vectors of machine integers of various size and signedness.
297    use super::*;
298
299    /// An F* attribute that marks an item as being an interpretation lemma.
300    #[allow(dead_code)]
301    #[hax_lib::fstar::before("irreducible")]
302    pub const SIMPLIFICATION_LEMMA: () = ();
303
304    /// Derives interpretations functions, simplification lemmas and type
305    /// synonyms.
306    macro_rules! interpretations {
307        ($n:literal; $($name:ident [$ty:ty; $m:literal]),*) => {
308            $(
309                #[doc = concat!(stringify!($ty), " vectors of size ", stringify!($m))]
310                #[allow(non_camel_case_types)]
311                pub type $name = FunArray<$m, $ty>;
312                pastey::paste! {
313                    const _: ()  = {
314                        #[hax_lib::opaque]
315                        impl BitVec<$n> {
316                            #[doc = concat!("Conversion from ", stringify!($ty), " vectors of size ", stringify!($m), "to  bit vectors of size ", stringify!($n))]
317                            pub fn [< from_ $name >](iv: $name) -> BitVec<$n> {
318                                let vec: Vec<$ty> = iv.as_vec();
319                                Self::from_slice(&vec[..], <$ty>::bits() as u64)
320                            }
321                            #[doc = concat!("Conversion from bit vectors of size ", stringify!($n), " to ", stringify!($ty), " vectors of size ", stringify!($m))]
322                            pub fn [< to_ $name >](bv: BitVec<$n>) -> $name {
323                                let vec: Vec<$ty> = bv.to_vec();
324                                $name::from_fn(|i| vec[i as usize])
325                            }
326                        }
327
328                        #[cfg(test)]
329                        impl From<BitVec<$n>> for $name {
330                            fn from(bv: BitVec<$n>) -> Self {
331                                BitVec::[< to_ $name >](bv)
332                            }
333                        }
334                        #[cfg(test)]
335                        impl From<$name> for BitVec<$n> {
336                            fn from(iv: $name) -> Self {
337                                BitVec::[< from_ $name >](iv)
338                            }
339                        }
340
341                        #[doc = concat!("Lemma that asserts that applying ", stringify!(BitVec::<$n>::from)," and then ", stringify!($name::from), " is the identity.")]
342                        #[hax_lib::fstar::before("[@@ $SIMPLIFICATION_LEMMA ]")]
343                        #[hax_lib::opaque]
344                        #[hax_lib::lemma]
345                        #[hax_lib::fstar::smt_pat(BitVec::[< to_ $name >](BitVec::[<from_ $name>](x)))]
346                        pub fn lemma_cancel_iv(x: $name) -> Proof<{
347                            hax_lib::eq(BitVec::[< to_ $name >](BitVec::[<from_ $name>](x)), x)
348                        }> {}
349                        #[doc = concat!("Lemma that asserts that applying ", stringify!($name::from)," and then ", stringify!(BitVec::<$n>::from), " is the identity.")]
350                        #[hax_lib::fstar::before("[@@ $SIMPLIFICATION_LEMMA ]")]
351                        #[hax_lib::opaque]
352                        #[hax_lib::lemma]
353                        #[hax_lib::fstar::smt_pat(BitVec::[< from_ $name >](BitVec::[<to_ $name>](x)))]
354                        pub fn lemma_cancel_bv(x: BitVec<$n>) -> Proof<{
355                            hax_lib::eq(BitVec::[< from_ $name >](BitVec::[<to_ $name>](x)), x)
356                        }> {}
357                    };
358                }
359            )*
360        };
361    }
362
363    // Defines the types `i32x8` and `i64x4`, and define intepretations function
364    // (`From` instances) from/to those types from/to bit vectors.
365    //
366    // We will need more such interpreations in the future to handle more avx2
367    // intrinsics (e.g. `_mm256_add_epi16` works on 16 bits integers, not on i32
368    // or i64).
369    interpretations!(256; i32x8 [i32; 8], i64x4 [i64; 4], i16x16 [i16; 16], i128x2 [i128; 2], i8x32 [i8; 32],
370		     u32x8 [u32; 8], u64x4 [u64; 4], u16x16 [u16; 16]);
371    interpretations!(128; i32x4 [i32; 4], i64x2 [i64; 2], i16x8 [i16; 8], i128x1 [i128; 1], i8x16 [i8; 16],
372		     u32x4 [u32; 4], u64x2 [u64; 2], u16x8 [u16; 8]);
373
374    impl i64x4 {
375        pub fn into_i32x8(self) -> i32x8 {
376            i32x8::from_fn(|i| {
377                let value = *self.get(i / 2);
378                (if i % 2 == 0 { value } else { value >> 32 }) as i32
379            })
380        }
381    }
382
383    impl i32x8 {
384        pub fn into_i64x4(self) -> i64x4 {
385            i64x4::from_fn(|i| {
386                let low = *self.get(2 * i) as u32 as u64;
387                let high = *self.get(2 * i + 1) as i32 as i64;
388                (high << 32) | low as i64
389            })
390        }
391    }
392
393    impl From<i64x4> for i32x8 {
394        fn from(vec: i64x4) -> Self {
395            vec.into_i32x8()
396        }
397    }
398
399    /// Lemma stating that converting an `i64x4` vector to a `BitVec<256>` and then into an `i32x8`
400    /// yields the same result as directly converting the `i64x4` into an `i32x8`.
401    #[hax_lib::fstar::before("[@@ $SIMPLIFICATION_LEMMA ]")]
402    #[hax_lib::opaque]
403    #[hax_lib::lemma]
404    pub fn lemma_rewrite_i64x4_bv_i32x8(
405        bv: i64x4,
406    ) -> Proof<{ hax_lib::eq(BitVec::to_i32x8(BitVec::from_i64x4(bv)), bv.into_i32x8()) }> {
407    }
408
409    /// Lemma stating that converting an `i64x4` vector to a `BitVec<256>` and then into an `i32x8`
410    /// yields the same result as directly converting the `i64x4` into an `i32x8`.
411    #[hax_lib::fstar::before("[@@ $SIMPLIFICATION_LEMMA ]")]
412    #[hax_lib::opaque]
413    #[hax_lib::lemma]
414    pub fn lemma_rewrite_i32x8_bv_i64x4(
415        bv: i32x8,
416    ) -> Proof<{ hax_lib::eq(BitVec::to_i64x4(BitVec::from_i32x8(bv)), bv.into_i64x4()) }> {
417    }
418
419    /// Normalize `from` calls that convert from one type to itself
420    #[hax_lib::fstar::replace(
421        r#"
422        [@@ $SIMPLIFICATION_LEMMA ]
423        let lemma (t: Type) (i: Core.Convert.t_From t t) (x: t)
424            : Lemma (Core.Convert.f_from #t #t #i x == (norm [primops; iota; delta; zeta] i.f_from) x)
425            = ()
426    "#
427    )]
428    const _: () = ();
429
430    #[cfg(test)]
431    mod direct_convertions_tests {
432        use super::*;
433        use crate::helpers::test::HasRandom;
434
435        #[test]
436        fn into_i32x8() {
437            for _ in 0..10000 {
438                let x: i64x4 = i64x4::random();
439                let y = x.into_i32x8();
440                assert_eq!(BitVec::from_i64x4(x), BitVec::from_i32x8(y));
441            }
442        }
443        #[test]
444        fn into_i64x4() {
445            let x: i32x8 = i32x8::random();
446            let y = x.into_i64x4();
447            assert_eq!(BitVec::from_i32x8(x), BitVec::from_i64x4(y));
448        }
449    }
450}