1use crate::ternary::Trit;
51use crate::vsa::SparseVec;
52
/// Reports whether the running CPU provides AVX-512 Foundation together with
/// the VPOPCNTDQ extension (per-lane 64-bit population count).
#[cfg(target_arch = "x86_64")]
fn has_avx512_vpopcntdq() -> bool {
    // Probe the base feature first; the popcount extension is only probed
    // when AVX-512F is present.
    if !std::arch::is_x86_feature_detected!("avx512f") {
        return false;
    }
    std::arch::is_x86_feature_detected!("avx512vpopcntdq")
}
60
/// Reports whether the running CPU supports the AVX2 instruction set.
#[cfg(target_arch = "x86_64")]
#[inline]
fn has_avx2() -> bool {
    std::arch::is_x86_feature_detected!("avx2")
}
65
66#[derive(Clone, Debug, PartialEq, Eq)]
67pub struct PackedTritVec {
68 len: usize,
69 data: Vec<u64>,
70}
71
72impl PackedTritVec {
73 const MASK_EVEN_BITS: u64 = 0x5555_5555_5555_5555;
74
75 #[inline]
76 fn ensure_len_and_clear(&mut self, len: usize) {
77 self.len = len;
78 let words = Self::word_count_for_len(len);
79 if self.data.len() != words {
80 self.data.resize(words, 0u64);
81 }
82 self.data.fill(0u64);
83 }
84
85 pub fn new_zero(len: usize) -> Self {
86 let bits = len.saturating_mul(2);
87 let words = bits.div_ceil(64);
88 Self {
89 len,
90 data: vec![0u64; words],
91 }
92 }
93
94 #[inline]
95 fn word_count_for_len(len: usize) -> usize {
96 let bits = len.saturating_mul(2);
97 bits.div_ceil(64)
98 }
99
100 #[inline]
101 fn last_word_mask(len: usize) -> u64 {
102 let lanes_in_last = len % 32;
103 if lanes_in_last == 0 {
104 !0u64
105 } else {
106 let used_bits = lanes_in_last * 2;
107 if used_bits >= 64 {
108 !0u64
109 } else {
110 (1u64 << used_bits) - 1
111 }
112 }
113 }
114
115 pub fn len(&self) -> usize {
116 self.len
117 }
118
119 pub fn is_empty(&self) -> bool {
120 self.len == 0
121 }
122
123 #[inline]
124 fn word_bit_index(i: usize) -> (usize, usize) {
125 let bit = i * 2;
126 (bit / 64, bit % 64)
127 }
128
129 pub fn get(&self, i: usize) -> Trit {
130 if i >= self.len {
131 return Trit::Z;
132 }
133 let (word, bit) = Self::word_bit_index(i);
134 let w = self.data.get(word).copied().unwrap_or(0);
135 let v = (w >> bit) & 0b11;
136 match v {
137 0b01 => Trit::P,
138 0b10 => Trit::N,
139 _ => Trit::Z,
140 }
141 }
142
143 pub fn set(&mut self, i: usize, t: Trit) {
144 if i >= self.len {
145 return;
146 }
147 let (word, bit) = Self::word_bit_index(i);
148 if let Some(w) = self.data.get_mut(word) {
149 *w &= !(0b11u64 << bit);
150 let enc = match t {
151 Trit::Z => 0b00u64,
152 Trit::P => 0b01u64,
153 Trit::N => 0b10u64,
154 };
155 *w |= enc << bit;
156 }
157 }
158
159 pub fn from_sparsevec(vec: &SparseVec, len: usize) -> Self {
160 let mut out = Self::new_zero(len);
161 out.fill_from_sparsevec(vec, len);
162 out
163 }
164
165 pub fn fill_from_sparsevec(&mut self, vec: &SparseVec, len: usize) {
169 self.ensure_len_and_clear(len);
170
171 for &idx in &vec.pos {
174 if idx < len {
175 let bit = idx * 2;
176 let word = bit / 64;
177 let shift = bit % 64;
178 self.data[word] |= 1u64 << shift;
179 }
180 }
181
182 for &idx in &vec.neg {
183 if idx < len {
184 let bit = idx * 2;
185 let word = bit / 64;
186 let shift = bit % 64;
187 self.data[word] |= 1u64 << (shift + 1);
188 }
189 }
190
191 if !self.data.is_empty() {
192 let last = self.data.len() - 1;
193 self.data[last] &= Self::last_word_mask(self.len);
194 }
195 }
196
197 pub fn to_sparsevec(&self) -> SparseVec {
198 let mut pos: Vec<usize> = Vec::new();
199 let mut neg: Vec<usize> = Vec::new();
200
201 for (word_idx, &word_raw) in self.data.iter().enumerate() {
203 let mut word = word_raw;
204 if word_idx + 1 == self.data.len() {
205 word &= Self::last_word_mask(self.len);
206 }
207
208 let pos_bits = word & Self::MASK_EVEN_BITS;
211 let neg_bits = (word >> 1) & Self::MASK_EVEN_BITS;
212
213 let conflict_bits = pos_bits & neg_bits;
217 let clean_pos = pos_bits & !conflict_bits;
218 let clean_neg = neg_bits & !conflict_bits;
219
220 let mut m = clean_pos;
222 while m != 0 {
223 let tz = m.trailing_zeros() as usize;
224 let lane = tz / 2;
225 let idx = word_idx * 32 + lane;
226 if idx < self.len {
227 pos.push(idx);
228 }
229 m &= m - 1;
230 }
231
232 let mut n = clean_neg;
234 while n != 0 {
235 let tz = n.trailing_zeros() as usize;
236 let lane = tz / 2;
237 let idx = word_idx * 32 + lane;
238 if idx < self.len {
239 neg.push(idx);
240 }
241 n &= n - 1;
242 }
243 }
244
245 SparseVec { pos, neg }
246 }
247
248 pub fn dot(&self, other: &Self) -> i32 {
255 let n = self.len.min(other.len);
256 if n == 0 {
257 return 0;
258 }
259
260 #[cfg(target_arch = "x86_64")]
261 {
262 let words = Self::word_count_for_len(n)
263 .min(self.data.len())
264 .min(other.data.len());
265
266 if words >= 16 && has_avx512_vpopcntdq() {
268 return unsafe { self.dot_avx512(other, n) };
269 }
270
271 if words >= 8 && has_avx2() {
273 return unsafe { self.dot_avx2(other, n) };
274 }
275 }
276
277 self.dot_scalar(other, n)
279 }
280
281 #[inline]
283 fn dot_scalar(&self, other: &Self, n: usize) -> i32 {
284 let words = Self::word_count_for_len(n)
285 .min(self.data.len())
286 .min(other.data.len());
287
288 let mut acc: i32 = 0;
289 for w in 0..words {
290 let mut a = self.data[w];
291 let mut b = other.data[w];
292 if w + 1 == words {
293 let mask = Self::last_word_mask(n);
294 a &= mask;
295 b &= mask;
296 }
297
298 let a_pos = a & Self::MASK_EVEN_BITS;
299 let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
300 let b_pos = b & Self::MASK_EVEN_BITS;
301 let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;
302
303 let pp = (a_pos & b_pos).count_ones() as i32;
304 let nn = (a_neg & b_neg).count_ones() as i32;
305 let pn = (a_pos & b_neg).count_ones() as i32;
306 let np = (a_neg & b_pos).count_ones() as i32;
307
308 acc += (pp + nn) - (pn + np);
309 }
310
311 acc
312 }
313
314 #[cfg(target_arch = "x86_64")]
321 #[target_feature(enable = "avx512f", enable = "avx512vpopcntdq")]
322 unsafe fn dot_avx512(&self, other: &Self, n: usize) -> i32 {
323 use std::arch::x86_64::*;
324
325 let words = Self::word_count_for_len(n)
326 .min(self.data.len())
327 .min(other.data.len());
328
329 let mask_even = _mm512_set1_epi64(Self::MASK_EVEN_BITS as i64);
331
332 let mut acc_pos = _mm512_setzero_si512(); let mut acc_neg = _mm512_setzero_si512(); let chunks = words / 8;
338 for chunk in 0..chunks {
339 let base = chunk * 8;
340
341 let va = _mm512_loadu_si512(self.data[base..].as_ptr() as *const __m512i);
343 let vb = _mm512_loadu_si512(other.data[base..].as_ptr() as *const __m512i);
344
345 let a_pos = _mm512_and_si512(va, mask_even);
347 let a_neg = _mm512_and_si512(_mm512_srli_epi64(va, 1), mask_even);
348 let b_pos = _mm512_and_si512(vb, mask_even);
349 let b_neg = _mm512_and_si512(_mm512_srli_epi64(vb, 1), mask_even);
350
351 let pp = _mm512_and_si512(a_pos, b_pos); let nn = _mm512_and_si512(a_neg, b_neg); let pn = _mm512_and_si512(a_pos, b_neg); let np = _mm512_and_si512(a_neg, b_pos); let pp_cnt = _mm512_popcnt_epi64(pp);
359 let nn_cnt = _mm512_popcnt_epi64(nn);
360 let pn_cnt = _mm512_popcnt_epi64(pn);
361 let np_cnt = _mm512_popcnt_epi64(np);
362
363 acc_pos = _mm512_add_epi64(acc_pos, pp_cnt);
365 acc_pos = _mm512_add_epi64(acc_pos, nn_cnt);
366
367 acc_neg = _mm512_add_epi64(acc_neg, pn_cnt);
369 acc_neg = _mm512_add_epi64(acc_neg, np_cnt);
370 }
371
372 let pos_sum = _mm512_reduce_add_epi64(acc_pos);
374 let neg_sum = _mm512_reduce_add_epi64(acc_neg);
375 let mut acc = (pos_sum - neg_sum) as i32;
376
377 let remainder_start = chunks * 8;
379 for w in remainder_start..words {
380 let mut a = self.data[w];
381 let mut b = other.data[w];
382 if w + 1 == words {
383 let mask = Self::last_word_mask(n);
384 a &= mask;
385 b &= mask;
386 }
387
388 let a_pos = a & Self::MASK_EVEN_BITS;
389 let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
390 let b_pos = b & Self::MASK_EVEN_BITS;
391 let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;
392
393 let pp = (a_pos & b_pos).count_ones() as i32;
394 let nn = (a_neg & b_neg).count_ones() as i32;
395 let pn = (a_pos & b_neg).count_ones() as i32;
396 let np = (a_neg & b_pos).count_ones() as i32;
397
398 acc += (pp + nn) - (pn + np);
399 }
400
401 acc
402 }
403
404 #[cfg(target_arch = "x86_64")]
412 #[target_feature(enable = "avx2")]
413 unsafe fn dot_avx2(&self, other: &Self, n: usize) -> i32 {
414 use std::arch::x86_64::*;
415
416 let words = Self::word_count_for_len(n)
417 .min(self.data.len())
418 .min(other.data.len());
419
420 let mask_even = _mm256_set1_epi64x(Self::MASK_EVEN_BITS as i64);
422
423 let popcount_lut = _mm256_setr_epi8(
425 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2,
426 3, 3, 4,
427 );
428 let low_nibble_mask = _mm256_set1_epi8(0x0F);
429
430 let mut acc_pos = _mm256_setzero_si256();
432 let mut acc_neg = _mm256_setzero_si256();
433
434 let chunks = words / 4;
436 for chunk in 0..chunks {
437 let base = chunk * 4;
438
439 let va = _mm256_loadu_si256(self.data[base..].as_ptr() as *const __m256i);
441 let vb = _mm256_loadu_si256(other.data[base..].as_ptr() as *const __m256i);
442
443 let a_pos = _mm256_and_si256(va, mask_even);
445 let a_neg = _mm256_and_si256(_mm256_srli_epi64(va, 1), mask_even);
446 let b_pos = _mm256_and_si256(vb, mask_even);
447 let b_neg = _mm256_and_si256(_mm256_srli_epi64(vb, 1), mask_even);
448
449 let pp = _mm256_and_si256(a_pos, b_pos);
451 let nn = _mm256_and_si256(a_neg, b_neg);
452 let pn = _mm256_and_si256(a_pos, b_neg);
453 let np = _mm256_and_si256(a_neg, b_pos);
454
455 let pp_lo = _mm256_shuffle_epi8(popcount_lut, _mm256_and_si256(pp, low_nibble_mask));
457 let pp_hi = _mm256_shuffle_epi8(
458 popcount_lut,
459 _mm256_and_si256(_mm256_srli_epi16(pp, 4), low_nibble_mask),
460 );
461 let nn_lo = _mm256_shuffle_epi8(popcount_lut, _mm256_and_si256(nn, low_nibble_mask));
462 let nn_hi = _mm256_shuffle_epi8(
463 popcount_lut,
464 _mm256_and_si256(_mm256_srli_epi16(nn, 4), low_nibble_mask),
465 );
466
467 let pn_lo = _mm256_shuffle_epi8(popcount_lut, _mm256_and_si256(pn, low_nibble_mask));
468 let pn_hi = _mm256_shuffle_epi8(
469 popcount_lut,
470 _mm256_and_si256(_mm256_srli_epi16(pn, 4), low_nibble_mask),
471 );
472 let np_lo = _mm256_shuffle_epi8(popcount_lut, _mm256_and_si256(np, low_nibble_mask));
473 let np_hi = _mm256_shuffle_epi8(
474 popcount_lut,
475 _mm256_and_si256(_mm256_srli_epi16(np, 4), low_nibble_mask),
476 );
477
478 let pos_bytes =
480 _mm256_add_epi8(_mm256_add_epi8(pp_lo, pp_hi), _mm256_add_epi8(nn_lo, nn_hi));
481 let neg_bytes =
482 _mm256_add_epi8(_mm256_add_epi8(pn_lo, pn_hi), _mm256_add_epi8(np_lo, np_hi));
483
484 let pos_sad = _mm256_sad_epu8(pos_bytes, _mm256_setzero_si256());
486 let neg_sad = _mm256_sad_epu8(neg_bytes, _mm256_setzero_si256());
487
488 acc_pos = _mm256_add_epi64(acc_pos, pos_sad);
489 acc_neg = _mm256_add_epi64(acc_neg, neg_sad);
490 }
491
492 let pos_lo = _mm256_castsi256_si128(acc_pos);
494 let pos_hi = _mm256_extracti128_si256(acc_pos, 1);
495 let pos_sum128 = _mm_add_epi64(pos_lo, pos_hi);
496
497 let neg_lo = _mm256_castsi256_si128(acc_neg);
498 let neg_hi = _mm256_extracti128_si256(acc_neg, 1);
499 let neg_sum128 = _mm_add_epi64(neg_lo, neg_hi);
500
501 let pos_final = _mm_extract_epi64(pos_sum128, 0) + _mm_extract_epi64(pos_sum128, 1);
502 let neg_final = _mm_extract_epi64(neg_sum128, 0) + _mm_extract_epi64(neg_sum128, 1);
503
504 let mut acc = (pos_final - neg_final) as i32;
505
506 let remainder_start = chunks * 4;
508 for w in remainder_start..words {
509 let mut a = self.data[w];
510 let mut b = other.data[w];
511 if w + 1 == words {
512 let mask = Self::last_word_mask(n);
513 a &= mask;
514 b &= mask;
515 }
516
517 let a_pos = a & Self::MASK_EVEN_BITS;
518 let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
519 let b_pos = b & Self::MASK_EVEN_BITS;
520 let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;
521
522 let pp = (a_pos & b_pos).count_ones() as i32;
523 let nn = (a_neg & b_neg).count_ones() as i32;
524 let pn = (a_pos & b_neg).count_ones() as i32;
525 let np = (a_neg & b_pos).count_ones() as i32;
526
527 acc += (pp + nn) - (pn + np);
528 }
529
530 acc
531 }
532
533 pub fn bind(&self, other: &Self) -> Self {
540 let n = self.len.min(other.len);
541 if n == 0 {
542 return Self::new_zero(0);
543 }
544
545 let words = Self::word_count_for_len(n)
546 .min(self.data.len())
547 .min(other.data.len());
548 let mut out = Self::new_zero(n);
549
550 #[cfg(target_arch = "x86_64")]
551 {
552 if words >= 8 && std::arch::is_x86_feature_detected!("avx512f") {
554 unsafe { self.bind_avx512(other, n, &mut out) };
555 return out;
556 }
557
558 if words >= 4 && has_avx2() {
560 unsafe { self.bind_avx2(other, n, &mut out) };
561 return out;
562 }
563 }
564
565 self.bind_scalar(other, n, &mut out);
567 out
568 }
569
570 #[inline]
572 fn bind_scalar(&self, other: &Self, n: usize, out: &mut Self) {
573 let words = Self::word_count_for_len(n)
574 .min(self.data.len())
575 .min(other.data.len())
576 .min(out.data.len());
577
578 for w in 0..words {
579 let mut a = self.data[w];
580 let mut b = other.data[w];
581 if w + 1 == words {
582 let mask = Self::last_word_mask(n);
583 a &= mask;
584 b &= mask;
585 }
586
587 let a_pos = a & Self::MASK_EVEN_BITS;
588 let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
589 let b_pos = b & Self::MASK_EVEN_BITS;
590 let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;
591
592 let same = (a_pos & b_pos) | (a_neg & b_neg);
593 let opp = (a_pos & b_neg) | (a_neg & b_pos);
594
595 out.data[w] = same | (opp << 1);
596 }
597
598 if !out.data.is_empty() {
600 let last = out.data.len() - 1;
601 out.data[last] &= Self::last_word_mask(out.len);
602 }
603 }
604
605 #[cfg(target_arch = "x86_64")]
612 #[target_feature(enable = "avx512f")]
613 unsafe fn bind_avx512(&self, other: &Self, n: usize, out: &mut Self) {
614 use std::arch::x86_64::*;
615
616 let words = Self::word_count_for_len(n)
617 .min(self.data.len())
618 .min(other.data.len())
619 .min(out.data.len());
620
621 let mask_even = _mm512_set1_epi64(Self::MASK_EVEN_BITS as i64);
622
623 let chunks = words / 8;
625 for chunk in 0..chunks {
626 let base = chunk * 8;
627
628 let va = _mm512_loadu_si512(self.data[base..].as_ptr() as *const __m512i);
629 let vb = _mm512_loadu_si512(other.data[base..].as_ptr() as *const __m512i);
630
631 let a_pos = _mm512_and_si512(va, mask_even);
633 let a_neg = _mm512_and_si512(_mm512_srli_epi64(va, 1), mask_even);
634 let b_pos = _mm512_and_si512(vb, mask_even);
635 let b_neg = _mm512_and_si512(_mm512_srli_epi64(vb, 1), mask_even);
636
637 let same = _mm512_or_si512(
639 _mm512_and_si512(a_pos, b_pos),
640 _mm512_and_si512(a_neg, b_neg),
641 );
642 let opp = _mm512_or_si512(
643 _mm512_and_si512(a_pos, b_neg),
644 _mm512_and_si512(a_neg, b_pos),
645 );
646
647 let result = _mm512_or_si512(same, _mm512_slli_epi64(opp, 1));
648 _mm512_storeu_si512(out.data[base..].as_mut_ptr() as *mut __m512i, result);
649 }
650
651 let remainder_start = chunks * 8;
653 for w in remainder_start..words {
654 let mut a = self.data[w];
655 let mut b = other.data[w];
656 if w + 1 == words {
657 let mask = Self::last_word_mask(n);
658 a &= mask;
659 b &= mask;
660 }
661
662 let a_pos = a & Self::MASK_EVEN_BITS;
663 let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
664 let b_pos = b & Self::MASK_EVEN_BITS;
665 let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;
666
667 let same = (a_pos & b_pos) | (a_neg & b_neg);
668 let opp = (a_pos & b_neg) | (a_neg & b_pos);
669
670 out.data[w] = same | (opp << 1);
671 }
672
673 if !out.data.is_empty() {
674 let last = out.data.len() - 1;
675 out.data[last] &= Self::last_word_mask(out.len);
676 }
677 }
678
679 #[cfg(target_arch = "x86_64")]
686 #[target_feature(enable = "avx2")]
687 unsafe fn bind_avx2(&self, other: &Self, n: usize, out: &mut Self) {
688 use std::arch::x86_64::*;
689
690 let words = Self::word_count_for_len(n)
691 .min(self.data.len())
692 .min(other.data.len())
693 .min(out.data.len());
694
695 let mask_even = _mm256_set1_epi64x(Self::MASK_EVEN_BITS as i64);
696
697 let chunks = words / 4;
699 for chunk in 0..chunks {
700 let base = chunk * 4;
701
702 let va = _mm256_loadu_si256(self.data[base..].as_ptr() as *const __m256i);
703 let vb = _mm256_loadu_si256(other.data[base..].as_ptr() as *const __m256i);
704
705 let a_pos = _mm256_and_si256(va, mask_even);
707 let a_neg = _mm256_and_si256(_mm256_srli_epi64(va, 1), mask_even);
708 let b_pos = _mm256_and_si256(vb, mask_even);
709 let b_neg = _mm256_and_si256(_mm256_srli_epi64(vb, 1), mask_even);
710
711 let same = _mm256_or_si256(
713 _mm256_and_si256(a_pos, b_pos),
714 _mm256_and_si256(a_neg, b_neg),
715 );
716 let opp = _mm256_or_si256(
717 _mm256_and_si256(a_pos, b_neg),
718 _mm256_and_si256(a_neg, b_pos),
719 );
720
721 let result = _mm256_or_si256(same, _mm256_slli_epi64(opp, 1));
722 _mm256_storeu_si256(out.data[base..].as_mut_ptr() as *mut __m256i, result);
723 }
724
725 let remainder_start = chunks * 4;
727 for w in remainder_start..words {
728 let mut a = self.data[w];
729 let mut b = other.data[w];
730 if w + 1 == words {
731 let mask = Self::last_word_mask(n);
732 a &= mask;
733 b &= mask;
734 }
735
736 let a_pos = a & Self::MASK_EVEN_BITS;
737 let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
738 let b_pos = b & Self::MASK_EVEN_BITS;
739 let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;
740
741 let same = (a_pos & b_pos) | (a_neg & b_neg);
742 let opp = (a_pos & b_neg) | (a_neg & b_pos);
743
744 out.data[w] = same | (opp << 1);
745 }
746
747 if !out.data.is_empty() {
748 let last = out.data.len() - 1;
749 out.data[last] &= Self::last_word_mask(out.len);
750 }
751 }
752
753 pub fn bind_into(&self, other: &Self, out: &mut Self) {
755 let n = self.len.min(other.len);
756 out.ensure_len_and_clear(n);
757 if n == 0 {
758 return;
759 }
760
761 let words = Self::word_count_for_len(n)
762 .min(self.data.len())
763 .min(other.data.len())
764 .min(out.data.len());
765
766 for w in 0..words {
767 let mut a = self.data[w];
768 let mut b = other.data[w];
769 if w + 1 == words {
770 let mask = Self::last_word_mask(n);
771 a &= mask;
772 b &= mask;
773 }
774
775 let a_pos = a & Self::MASK_EVEN_BITS;
776 let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
777 let b_pos = b & Self::MASK_EVEN_BITS;
778 let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;
779
780 let same = (a_pos & b_pos) | (a_neg & b_neg);
781 let opp = (a_pos & b_neg) | (a_neg & b_pos);
782
783 out.data[w] = same | (opp << 1);
784 }
785
786 if !out.data.is_empty() {
787 let last = out.data.len() - 1;
788 out.data[last] &= Self::last_word_mask(out.len);
789 }
790 }
791
792 pub fn bundle(&self, other: &Self) -> Self {
799 let n = self.len.min(other.len);
800 if n == 0 {
801 return Self::new_zero(0);
802 }
803
804 let words = Self::word_count_for_len(n)
805 .min(self.data.len())
806 .min(other.data.len());
807 let mut out = Self::new_zero(n);
808
809 #[cfg(target_arch = "x86_64")]
810 {
811 if words >= 8 && std::arch::is_x86_feature_detected!("avx512f") {
813 unsafe { self.bundle_avx512(other, n, &mut out) };
814 return out;
815 }
816
817 if words >= 4 && has_avx2() {
819 unsafe { self.bundle_avx2(other, n, &mut out) };
820 return out;
821 }
822 }
823
824 self.bundle_scalar(other, n, &mut out);
826 out
827 }
828
829 #[inline]
831 fn bundle_scalar(&self, other: &Self, n: usize, out: &mut Self) {
832 let words = Self::word_count_for_len(n)
833 .min(self.data.len())
834 .min(other.data.len())
835 .min(out.data.len());
836
837 for w in 0..words {
838 let mut a = self.data[w];
839 let mut b = other.data[w];
840 if w + 1 == words {
841 let mask = Self::last_word_mask(n);
842 a &= mask;
843 b &= mask;
844 }
845
846 let a_pos = a & Self::MASK_EVEN_BITS;
847 let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
848 let b_pos = b & Self::MASK_EVEN_BITS;
849 let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;
850
851 let mask = Self::MASK_EVEN_BITS;
852 let not_b_neg = (!b_neg) & mask;
853 let not_a_neg = (!a_neg) & mask;
854 let not_b_pos = (!b_pos) & mask;
855 let not_a_pos = (!a_pos) & mask;
856
857 let pos = (a_pos & not_b_neg) | (b_pos & not_a_neg);
858 let neg = (a_neg & not_b_pos) | (b_neg & not_a_pos);
859
860 out.data[w] = pos | (neg << 1);
861 }
862
863 if !out.data.is_empty() {
864 let last = out.data.len() - 1;
865 out.data[last] &= Self::last_word_mask(out.len);
866 }
867 }
868
869 #[cfg(target_arch = "x86_64")]
876 #[target_feature(enable = "avx512f")]
877 unsafe fn bundle_avx512(&self, other: &Self, n: usize, out: &mut Self) {
878 use std::arch::x86_64::*;
879
880 let words = Self::word_count_for_len(n)
881 .min(self.data.len())
882 .min(other.data.len())
883 .min(out.data.len());
884
885 let mask_even = _mm512_set1_epi64(Self::MASK_EVEN_BITS as i64);
886
887 let chunks = words / 8;
889 for chunk in 0..chunks {
890 let base = chunk * 8;
891
892 let va = _mm512_loadu_si512(self.data[base..].as_ptr() as *const __m512i);
893 let vb = _mm512_loadu_si512(other.data[base..].as_ptr() as *const __m512i);
894
895 let a_pos = _mm512_and_si512(va, mask_even);
897 let a_neg = _mm512_and_si512(_mm512_srli_epi64(va, 1), mask_even);
898 let b_pos = _mm512_and_si512(vb, mask_even);
899 let b_neg = _mm512_and_si512(_mm512_srli_epi64(vb, 1), mask_even);
900
901 let not_b_neg = _mm512_andnot_si512(b_neg, mask_even);
905 let not_a_neg = _mm512_andnot_si512(a_neg, mask_even);
906 let not_b_pos = _mm512_andnot_si512(b_pos, mask_even);
907 let not_a_pos = _mm512_andnot_si512(a_pos, mask_even);
908
909 let pos = _mm512_or_si512(
910 _mm512_and_si512(a_pos, not_b_neg),
911 _mm512_and_si512(b_pos, not_a_neg),
912 );
913 let neg = _mm512_or_si512(
914 _mm512_and_si512(a_neg, not_b_pos),
915 _mm512_and_si512(b_neg, not_a_pos),
916 );
917
918 let result = _mm512_or_si512(pos, _mm512_slli_epi64(neg, 1));
919 _mm512_storeu_si512(out.data[base..].as_mut_ptr() as *mut __m512i, result);
920 }
921
922 let remainder_start = chunks * 8;
924 for w in remainder_start..words {
925 let mut a = self.data[w];
926 let mut b = other.data[w];
927 if w + 1 == words {
928 let mask = Self::last_word_mask(n);
929 a &= mask;
930 b &= mask;
931 }
932
933 let a_pos = a & Self::MASK_EVEN_BITS;
934 let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
935 let b_pos = b & Self::MASK_EVEN_BITS;
936 let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;
937
938 let mask = Self::MASK_EVEN_BITS;
939 let not_b_neg = (!b_neg) & mask;
940 let not_a_neg = (!a_neg) & mask;
941 let not_b_pos = (!b_pos) & mask;
942 let not_a_pos = (!a_pos) & mask;
943
944 let pos = (a_pos & not_b_neg) | (b_pos & not_a_neg);
945 let neg = (a_neg & not_b_pos) | (b_neg & not_a_pos);
946
947 out.data[w] = pos | (neg << 1);
948 }
949
950 if !out.data.is_empty() {
951 let last = out.data.len() - 1;
952 out.data[last] &= Self::last_word_mask(out.len);
953 }
954 }
955
956 #[cfg(target_arch = "x86_64")]
963 #[target_feature(enable = "avx2")]
964 unsafe fn bundle_avx2(&self, other: &Self, n: usize, out: &mut Self) {
965 use std::arch::x86_64::*;
966
967 let words = Self::word_count_for_len(n)
968 .min(self.data.len())
969 .min(other.data.len())
970 .min(out.data.len());
971
972 let mask_even = _mm256_set1_epi64x(Self::MASK_EVEN_BITS as i64);
973
974 let chunks = words / 4;
976 for chunk in 0..chunks {
977 let base = chunk * 4;
978
979 let va = _mm256_loadu_si256(self.data[base..].as_ptr() as *const __m256i);
980 let vb = _mm256_loadu_si256(other.data[base..].as_ptr() as *const __m256i);
981
982 let a_pos = _mm256_and_si256(va, mask_even);
984 let a_neg = _mm256_and_si256(_mm256_srli_epi64(va, 1), mask_even);
985 let b_pos = _mm256_and_si256(vb, mask_even);
986 let b_neg = _mm256_and_si256(_mm256_srli_epi64(vb, 1), mask_even);
987
988 let not_b_neg = _mm256_andnot_si256(b_neg, mask_even);
990 let not_a_neg = _mm256_andnot_si256(a_neg, mask_even);
991 let not_b_pos = _mm256_andnot_si256(b_pos, mask_even);
992 let not_a_pos = _mm256_andnot_si256(a_pos, mask_even);
993
994 let pos = _mm256_or_si256(
995 _mm256_and_si256(a_pos, not_b_neg),
996 _mm256_and_si256(b_pos, not_a_neg),
997 );
998 let neg = _mm256_or_si256(
999 _mm256_and_si256(a_neg, not_b_pos),
1000 _mm256_and_si256(b_neg, not_a_pos),
1001 );
1002
1003 let result = _mm256_or_si256(pos, _mm256_slli_epi64(neg, 1));
1004 _mm256_storeu_si256(out.data[base..].as_mut_ptr() as *mut __m256i, result);
1005 }
1006
1007 let remainder_start = chunks * 4;
1009 for w in remainder_start..words {
1010 let mut a = self.data[w];
1011 let mut b = other.data[w];
1012 if w + 1 == words {
1013 let mask = Self::last_word_mask(n);
1014 a &= mask;
1015 b &= mask;
1016 }
1017
1018 let a_pos = a & Self::MASK_EVEN_BITS;
1019 let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
1020 let b_pos = b & Self::MASK_EVEN_BITS;
1021 let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;
1022
1023 let mask = Self::MASK_EVEN_BITS;
1024 let not_b_neg = (!b_neg) & mask;
1025 let not_a_neg = (!a_neg) & mask;
1026 let not_b_pos = (!b_pos) & mask;
1027 let not_a_pos = (!a_pos) & mask;
1028
1029 let pos = (a_pos & not_b_neg) | (b_pos & not_a_neg);
1030 let neg = (a_neg & not_b_pos) | (b_neg & not_a_pos);
1031
1032 out.data[w] = pos | (neg << 1);
1033 }
1034
1035 if !out.data.is_empty() {
1036 let last = out.data.len() - 1;
1037 out.data[last] &= Self::last_word_mask(out.len);
1038 }
1039 }
1040
1041 pub fn bundle_into(&self, other: &Self, out: &mut Self) {
1043 let n = self.len.min(other.len);
1044 out.ensure_len_and_clear(n);
1045 if n == 0 {
1046 return;
1047 }
1048
1049 let words = Self::word_count_for_len(n)
1050 .min(self.data.len())
1051 .min(other.data.len())
1052 .min(out.data.len());
1053
1054 for w in 0..words {
1055 let mut a = self.data[w];
1056 let mut b = other.data[w];
1057 if w + 1 == words {
1058 let mask = Self::last_word_mask(n);
1059 a &= mask;
1060 b &= mask;
1061 }
1062
1063 let a_pos = a & Self::MASK_EVEN_BITS;
1064 let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
1065 let b_pos = b & Self::MASK_EVEN_BITS;
1066 let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;
1067
1068 let mask = Self::MASK_EVEN_BITS;
1069 let not_b_neg = (!b_neg) & mask;
1070 let not_a_neg = (!a_neg) & mask;
1071 let not_b_pos = (!b_pos) & mask;
1072 let not_a_pos = (!a_pos) & mask;
1073
1074 let pos = (a_pos & not_b_neg) | (b_pos & not_a_neg);
1075 let neg = (a_neg & not_b_pos) | (b_neg & not_a_pos);
1076
1077 out.data[w] = pos | (neg << 1);
1078 }
1079
1080 if !out.data.is_empty() {
1081 let last = out.data.len() - 1;
1082 out.data[last] &= Self::last_word_mask(out.len);
1083 }
1084 }
1085}