core_models/abstractions/
bitvec.rs1use super::bit::{Bit, MachineInteger};
3use super::funarr::*;
4
5use std::fmt::Formatter;
6
7#[cfg(hax)]
9use hax_lib::{int, ToInt};
10
11#[hax_lib::fstar::before("noeq")]
24#[derive(Copy, Clone, Eq, PartialEq)]
25pub struct BitVec<const N: u64>(FunArray<N, Bit>);
26
27#[hax_lib::exclude]
29fn bit_slice_to_string(bits: &[Bit]) -> String {
30 bits.iter()
31 .map(|bit| match bit {
32 Bit::Zero => '0',
33 Bit::One => '1',
34 })
35 .collect::<Vec<_>>()
36 .chunks(8)
37 .map(|bits| bits.iter().collect::<String>())
38 .map(|s| format!("{s} "))
39 .collect::<String>()
40 .trim()
41 .into()
42}
43
44#[hax_lib::exclude]
45impl<const N: u64> core::fmt::Debug for BitVec<N> {
46 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
47 write!(f, "{}", bit_slice_to_string(&self.0.as_vec()))
48 }
49}
50
51#[hax_lib::attributes]
52impl<const N: u64> core::ops::Index<u64> for BitVec<N> {
53 type Output = Bit;
54 #[requires(index < N)]
55 fn index(&self, index: u64) -> &Self::Output {
56 self.0.get(index)
57 }
58}
59
60#[hax_lib::exclude]
62fn u128_int_from_bit_slice(bits: &[Bit]) -> u128 {
63 bits.iter()
64 .enumerate()
65 .map(|(i, bit)| u128::from(*bit) << i)
66 .sum::<u128>()
67}
68
69#[hax_lib::exclude]
71fn int_from_bit_slice<T: TryFrom<i128> + MachineInteger + Copy>(bits: &[Bit]) -> T {
72 debug_assert!(bits.len() <= T::bits() as usize);
73 let result = if T::SIGNED {
74 let is_negative = matches!(bits[T::bits() as usize - 1], Bit::One);
75 let s = u128_int_from_bit_slice(&bits[0..T::bits() as usize - 1]) as i128;
76 if is_negative {
77 s + (-2i128).pow(T::bits() - 1)
78 } else {
79 s
80 }
81 } else {
82 u128_int_from_bit_slice(bits) as i128
83 };
84 let Ok(n) = result.try_into() else {
85 unreachable!()
87 };
88 n
89}
90
// Hand-written F* definition of `BitVec::from_fn`.
//
// The Rust `from_fn` lives in the `#[hax_lib::exclude]` impl further down and
// is not extracted. This placeholder item is replaced by an F* `let` that
// wraps `FunArray::from_fn`, with `f` restricted to indices `i < N`. The
// `<0>` / `<0,()>` generic arguments appear to serve only to name the items
// in the template.
#[hax_lib::fstar::replace(
    r#"
let ${BitVec::<0>::from_fn::<fn(u64)->Bit>}
 (v_N: u64)
 (f: (i: u64 {v i < v v_N}) -> $:{Bit})
 : t_BitVec v_N =
 ${BitVec::<0>}(${FunArray::<0,()>::from_fn::<fn(u64)->()>} v_N f)
"#
)]
const _: () = ();
101
/// Defines `BitVec::<$width>::pointwise`, which rebuilds a bit vector by
/// copying every bit through an explicit `match` over the listed indices.
///
/// The result equals the input (this is what the F* lemmas
/// `bitvec_rewrite_lemma_128` / `bitvec_rewrite_lemma_256` state). The
/// unrolled `match` gives the F* normalizer a fully expanded term. Every index
/// in `0..$width` must be listed; an unlisted index hits the panicking
/// catch-all arm.
macro_rules! impl_pointwise {
    ($width:literal, $($idx:literal)*) => {
        impl BitVec<$width> {
            /// Identity on bit vectors, expressed bit by bit (see `impl_pointwise!`).
            pub fn pointwise(self) -> Self {
                Self::from_fn(|i| match i {
                    $($idx => self[$idx],)*
                    _ => unreachable!(),
                })
            }
        }
    };
}
114
// `pointwise` for the two widths referenced by `bitvec_rewrite_lemma_128` and
// `bitvec_rewrite_lemma_256`. Each invocation must list every index `0..N`,
// otherwise `pointwise` panics on the missing ones.
impl_pointwise!(128, 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127);
impl_pointwise!(256, 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255);
117
/// Marker used as an F* attribute. `bitvec_postprocess_norm` collects every
/// definition tagged with it (through `lookup_attr`) and applies them as
/// left-to-right rewrite rules.
pub const REWRITE_RULE: () = ();
120
121#[hax_lib::exclude]
122impl<const N: u64> BitVec<N> {
123 pub fn from_fn<F: Fn(u64) -> Bit>(f: F) -> Self {
125 Self(FunArray::from_fn(f))
126 }
127 pub fn from_slice<T: Into<i128> + MachineInteger + Copy>(x: &[T], d: u64) -> Self {
129 Self::from_fn(|i| Bit::of_int::<T>(x[(i / d) as usize], (i % d) as u32))
130 }
131
132 pub fn from_int<T: Into<i128> + MachineInteger + Copy>(n: T) -> Self {
134 Self::from_slice::<T>(&[n], T::bits() as u64)
135 }
136
137 pub fn to_int<T: TryFrom<i128> + MachineInteger + Copy>(self) -> T {
139 int_from_bit_slice(&self.0.as_vec())
140 }
141
142 pub fn to_vec<T: TryFrom<i128> + MachineInteger + Copy>(&self) -> Vec<T> {
144 self.0
145 .as_vec()
146 .chunks(T::bits() as usize)
147 .map(int_from_bit_slice)
148 .collect()
149 }
150
151 pub fn rand() -> Self {
153 use rand::prelude::*;
154 let random_source: Vec<_> = {
155 let mut rng = rand::rng();
156 (0..N).map(|_| rng.random::<bool>()).collect()
157 };
158 Self::from_fn(|i| random_source[i as usize].into())
159 }
160}
161
// F* implementation of the `bitvec_postprocess_norm` tactic. Per the comments
// inside the tactic, it normalizes a goal of the form `<body> == ?u` and
// closes it with `trefl`. The empty Rust function below is only a placeholder
// that this string replaces. In lax mode the tactic just calls `trefl`.
//
// Outline of `bitvec_postprocess_norm_aux`:
// 1. unfold definitions from the current crate and `Libcrux_intrinsics`;
// 2. rewrite left-to-right with every definition tagged `REWRITE_RULE`;
// 3. unfold `Rust_primitives` and `Prims.pow2` (casts);
// 4. rewrite with `bitvec_rewrite_lemma_128` / `bitvec_rewrite_lemma_256`,
//    which wrap bit vectors as `mark_to_normalize (pointwise x)`;
// 5. bottom-up, at every two-argument `mark_to_normalize` application,
//    normalize fully and close the subgoal with `trefl`.
//
// Enabling the F* extension `debug_bv_postprocess_rewrite` turns on the
// progress prints and goal dumps.
//
// NOTE(review): the two rewrite lemmas are checked under
// `--admit_smt_queries true`, so their SMT obligations are assumed rather
// than proven.
#[hax_lib::fstar::replace(
    r#"
open FStar.FunctionalExtensionality

let extensionality' (#a: Type) (#b: Type) (f g: FStar.FunctionalExtensionality.(a ^-> b))
 : Lemma (ensures (FStar.FunctionalExtensionality.feq f g <==> f == g))
 = ()

let mark_to_normalize #t (x: t): t = x

open FStar.Tactics.V2
#push-options "--z3rlimit 80 --admit_smt_queries true"
let bitvec_rewrite_lemma_128 (x: $:{BitVec<128>})
: Lemma (x == mark_to_normalize (${BitVec::<128>::pointwise} x)) =
 let a = x._0._0 in
 let b = (${BitVec::<128>::pointwise} x)._0._0 in
 assert_norm (FStar.FunctionalExtensionality.feq a b);
 extensionality' a b

let bitvec_rewrite_lemma_256 (x: $:{BitVec<256>})
: Lemma (x == mark_to_normalize (${BitVec::<256>::pointwise} x)) =
 let a = x._0._0 in
 let b = (${BitVec::<256>::pointwise} x)._0._0 in
 assert_norm (FStar.FunctionalExtensionality.feq a b);
 extensionality' a b
#pop-options

let bitvec_postprocess_norm_aux (): Tac unit = with_compat_pre_core 1 (fun () ->
 let debug_mode = ext_enabled "debug_bv_postprocess_rewrite" in
 let crate = match cur_module () with | crate::_ -> crate | _ -> fail "Empty module name" in
 // Remove indirections
 norm [primops; iota; delta_namespace [crate; "Libcrux_intrinsics"]; zeta_full];
 // Rewrite call chains
 let lemmas = FStar.List.Tot.map (fun f -> pack_ln (FStar.Stubs.Reflection.V2.Data.Tv_FVar f)) (lookup_attr (`${REWRITE_RULE}) (top_env ())) in
 l_to_r lemmas;
 /// Get rid of casts
 norm [primops; iota; delta_namespace ["Rust_primitives"; "Prims.pow2"]; zeta_full];
 if debug_mode then print ("[postprocess_rewrite_helper] lemmas = " ^ term_to_string (quote lemmas));

 l_to_r [`bitvec_rewrite_lemma_128; `bitvec_rewrite_lemma_256];

 let round _: Tac unit =
 if debug_mode then dump "[postprocess_rewrite_helper] Rewrote goal";
 // Normalize as much as possible
 norm [primops; iota; delta_namespace ["Core"; crate; "Core_models"; "Libcrux_intrinsics"; "FStar.FunctionalExtensionality"; "Rust_primitives"]; zeta_full];
 if debug_mode then print ("[postprocess_rewrite_helper] first norm done");
 // Compute the last bits
 // compute ();
 // if debug_mode then dump ("[postprocess_rewrite_helper] compute done");
 // Force full normalization
 norm [primops; iota; delta; unascribe; zeta_full];
 if debug_mode then dump "[postprocess_rewrite_helper] after full normalization";
 // Solves the goal `<normalized body> == ?u`
 trefl ()
 in

 ctrl_rewrite BottomUp (fun t ->
 let f, args = collect_app t in
 let matches = match inspect f with | Tv_UInst f _ | Tv_FVar f -> (inspect_fv f) = explode_qn (`%mark_to_normalize) | _ -> false in
 let has_two_args = match args with | [_; _] -> true | _ -> false in
 (matches && has_two_args, Continue)
 ) round;

 // Solves the goal `<normalized body> == ?u`
 trefl ()
)

let ${bitvec_postprocess_norm} (): Tac unit =
 if lax_on ()
 then trefl () // don't bother rewritting the goal
 else bitvec_postprocess_norm_aux ()
"#
)]
/// Placeholder for the F* tactic of the same name; the `fstar::replace`
/// attribute supplies the real definition. Does nothing when called from Rust.
pub fn bitvec_postprocess_norm() {}
241
#[hax_lib::attributes]
impl<const N: u64> BitVec<N> {
    /// Shifts each `CHUNK`-bit chunk of the vector by its own amount.
    ///
    /// The vector is viewed as `SHIFTS` consecutive chunks of `CHUNK` bits
    /// (the precondition requires `CHUNK * SHIFTS == N`). Output bit `b` of
    /// chunk `k` is input bit `b - shl[k]` of the same chunk when that index is
    /// in `0..CHUNK`, and `Bit::Zero` otherwise. So:
    /// * a positive `shl[k]` moves bits toward higher indices and a negative
    ///   one toward lower indices;
    /// * bits never cross chunk boundaries;
    /// * a shift of magnitude `CHUNK` or more clears the whole chunk.
    #[hax_lib::requires(CHUNK > 0 && CHUNK.to_int() * SHIFTS.to_int() == N.to_int())]
    pub fn chunked_shift<const CHUNK: u64, const SHIFTS: u64>(
        self,
        shl: FunArray<SHIFTS, i128>,
    ) -> BitVec<N> {
        // Standalone worker with `N` as its own const generic; the method
        // body only forwards to it.
        // Solver options for this function's F* verification conditions.
        #[hax_lib::fstar::options("--z3rlimit 50 --split_queries always")]
        #[hax_lib::requires(CHUNK > 0 && CHUNK.to_int() * SHIFTS.to_int() == N.to_int())]
        fn chunked_shift<const N: u64, const CHUNK: u64, const SHIFTS: u64>(
            bitvec: BitVec<N>,
            shl: FunArray<SHIFTS, i128>,
        ) -> BitVec<N> {
            BitVec::from_fn(|i| {
                // Position of bit `i` within its chunk, and the chunk's index.
                let nth_bit = i % CHUNK;
                let nth_chunk = i / CHUNK;
                // Proof hints for F*: bound the chunk index so that the source
                // index computed below stays under `N`.
                hax_lib::assert_prop!(nth_chunk.to_int() <= SHIFTS.to_int() - int!(1));
                hax_lib::assert_prop!(
                    nth_chunk.to_int() * CHUNK.to_int()
                        <= (SHIFTS.to_int() - int!(1)) * CHUNK.to_int()
                );
                // For every `i < N`, `nth_chunk < SHIFTS`. The explicit check
                // keeps the `shl` access visibly in bounds.
                let shift: i128 = if nth_chunk < SHIFTS {
                    shl[nth_chunk]
                } else {
                    0
                };
                // Source position inside the chunk. `wrapping_sub` avoids an
                // overflow panic for extreme negative shifts; a wrapped result
                // is negative and is rejected by the range check below.
                let local_index = (nth_bit as i128).wrapping_sub(shift);
                if local_index < CHUNK as i128 && local_index >= 0 {
                    let local_index = local_index as u64;
                    // Proof hint: the source index lies inside the vector.
                    hax_lib::assert_prop!(
                        nth_chunk.to_int() * CHUNK.to_int() + local_index.to_int()
                            < SHIFTS.to_int() * CHUNK.to_int()
                    );
                    bitvec[nth_chunk * CHUNK + local_index]
                } else {
                    // Shifted in from outside the chunk.
                    Bit::Zero
                }
            })
        }
        chunked_shift::<N, CHUNK, SHIFTS>(self, shl)
    }

    /// Folds `f` over the bits of the vector, starting from `init`, by
    /// delegating to the underlying `FunArray::fold`. The visiting order is
    /// that method's (presumably index order; TODO confirm).
    pub fn fold<A>(&self, init: A, f: fn(A, Bit) -> A) -> A {
        self.0.fold(init, f)
    }
}
294
295pub mod int_vec_interp {
296 use super::*;
298
299 #[allow(dead_code)]
301 #[hax_lib::fstar::before("irreducible")]
302 pub const SIMPLIFICATION_LEMMA: () = ();
303
304 macro_rules! interpretations {
307 ($n:literal; $($name:ident [$ty:ty; $m:literal]),*) => {
308 $(
309 #[doc = concat!(stringify!($ty), " vectors of size ", stringify!($m))]
310 #[allow(non_camel_case_types)]
311 pub type $name = FunArray<$m, $ty>;
312 pastey::paste! {
313 const _: () = {
314 #[hax_lib::opaque]
315 impl BitVec<$n> {
316 #[doc = concat!("Conversion from ", stringify!($ty), " vectors of size ", stringify!($m), "to bit vectors of size ", stringify!($n))]
317 pub fn [< from_ $name >](iv: $name) -> BitVec<$n> {
318 let vec: Vec<$ty> = iv.as_vec();
319 Self::from_slice(&vec[..], <$ty>::bits() as u64)
320 }
321 #[doc = concat!("Conversion from bit vectors of size ", stringify!($n), " to ", stringify!($ty), " vectors of size ", stringify!($m))]
322 pub fn [< to_ $name >](bv: BitVec<$n>) -> $name {
323 let vec: Vec<$ty> = bv.to_vec();
324 $name::from_fn(|i| vec[i as usize])
325 }
326 }
327
328 #[cfg(test)]
329 impl From<BitVec<$n>> for $name {
330 fn from(bv: BitVec<$n>) -> Self {
331 BitVec::[< to_ $name >](bv)
332 }
333 }
334 #[cfg(test)]
335 impl From<$name> for BitVec<$n> {
336 fn from(iv: $name) -> Self {
337 BitVec::[< from_ $name >](iv)
338 }
339 }
340 };
341 }
342 )*
343 };
344 }
345
346 interpretations!(256; i32x8 [i32; 8], i64x4 [i64; 4], i16x16 [i16; 16], i128x2 [i128; 2], i8x32 [i8; 32],
353 u32x8 [u32; 8], u64x4 [u64; 4], u16x16 [u16; 16]);
354 interpretations!(128; i32x4 [i32; 4], i64x2 [i64; 2], i16x8 [i16; 8], i128x1 [i128; 1], i8x16 [i8; 16],
355 u32x4 [u32; 4], u64x2 [u64; 2], u16x8 [u16; 8]);
356
357 impl i64x4 {
358 pub fn into_i32x8(self) -> i32x8 {
359 i32x8::from_fn(|i| {
360 let value = *self.get(i / 2);
361 (if i % 2 == 0 { value } else { value >> 32 }) as i32
362 })
363 }
364 }
365
366 impl i32x8 {
367 pub fn into_i64x4(self) -> i64x4 {
368 i64x4::from_fn(|i| {
369 let low = *self.get(2 * i) as u32 as u64;
370 let high = *self.get(2 * i + 1) as i32 as i64;
371 (high << 32) | low as i64
372 })
373 }
374 }
375
376 impl From<i64x4> for i32x8 {
377 fn from(vec: i64x4) -> Self {
378 vec.into_i32x8()
379 }
380 }
381
382 #[hax_lib::fstar::before("[@@ $SIMPLIFICATION_LEMMA ]")]
385 #[hax_lib::opaque]
386 #[hax_lib::lemma]
387 pub fn lemma_rewrite_i64x4_bv_i32x8(
388 bv: i64x4,
389 ) -> Proof<{ hax_lib::eq(BitVec::to_i32x8(BitVec::from_i64x4(bv)), bv.into_i32x8()) }> {
390 }
391
392 #[hax_lib::fstar::before("[@@ $SIMPLIFICATION_LEMMA ]")]
395 #[hax_lib::opaque]
396 #[hax_lib::lemma]
397 pub fn lemma_rewrite_i32x8_bv_i64x4(
398 bv: i32x8,
399 ) -> Proof<{ hax_lib::eq(BitVec::to_i64x4(BitVec::from_i32x8(bv)), bv.into_i64x4()) }> {
400 }
401
402 #[hax_lib::fstar::replace(
404 r#"
405 [@@ $SIMPLIFICATION_LEMMA ]
406 let lemma (t: Type) (i: Core.Convert.t_From t t) (x: t)
407 : Lemma (Core.Convert.f_from #t #t #i x == (norm [primops; iota; delta; zeta] i.f_from) x)
408 = ()
409 "#
410 )]
411 const _: () = ();
412
413 #[cfg(test)]
414 mod direct_convertions_tests {
415 use super::*;
416 use crate::helpers::test::HasRandom;
417
418 #[test]
419 fn into_i32x8() {
420 for _ in 0..10000 {
421 let x: i64x4 = i64x4::random();
422 let y = x.into_i32x8();
423 assert_eq!(BitVec::from_i64x4(x), BitVec::from_i32x8(y));
424 }
425 }
426 #[test]
427 fn into_i64x4() {
428 let x: i32x8 = i32x8::random();
429 let y = x.into_i64x4();
430 assert_eq!(BitVec::from_i32x8(x), BitVec::from_i64x4(y));
431 }
432 }
433}