
embeddenator_vsa/ternary_vec.rs

//! Packed ternary vector representation.
//!
//! This module implements the bitsliced balanced ternary representation that serves as
//! the internal substrate for fast dot/bind/bundle operations.
//!
//! # Bitsliced Design
//!
//! See `docs/BITSLICED_TERNARY_DESIGN.md` for comprehensive documentation.
//!
//! ## Representation: 2 bits per dimension
//! - 00 = Z (0)
//! - 01 = P (+1)
//! - 10 = N (-1)
//! - 11 = unused (treated as Z)
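//!
//! For example, the four trits `[+1, 0, -1, +1]` pack into the low byte of a
//! word as `0b0110_0001`: trit `i` occupies bits `2*i` (P plane) and `2*i + 1`
//! (N plane).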
//!
//! ## Key Features
//!
//! - **Word-level parallelism**: 32 trits per u64 word
//! - **Bitplane separation**: Even bits (P plane), odd bits (N plane)
//! - **SIMD-friendly**: Branchless operations, vectorization-ready
//! - **GPU-ready**: Coalesced memory access, no divergence
//!
//! ## Performance
//!
//! - Dot product: 32 trits per word operation
//! - Bind: Pure bitwise operations (AND, OR, shift)
//! - Bundle: Saturating addition via bitwise logic
//! - Memory: 2 bits per trit (simple, word-aligned encoding)
//!
//! ## When to Use
//!
//! Use `PackedTritVec` when:
//! - Vector density ≥ 25%
//! - Performing bulk operations (dot, bind, bundle)
//! - SIMD/GPU acceleration is needed
//! - Throughput matters more than latency
//!
//! Use `SparseVec` when:
//! - Vector density < 25%
//! - Access patterns are random
//! - Construction is incremental
//! - Memory is constrained
//!
//! ## Future Enhancements
//!
//! - Explicit SIMD (AVX2, AVX-512, NEON) - Phase 2
//! - GPU acceleration (CUDA, OpenCL) - Phase 5
//! - See `IMPLEMENTATION_PLAN.md` for roadmap
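//!
//! ## Example
//!
//! A minimal usage sketch (marked `ignore`: import paths are assumptions about
//! this crate's layout, not a doctest):
//!
//! ```ignore
//! use crate::ternary::Trit;
//!
//! // a = [+1, -1, 0, +1], b = [+1, +1, 0, -1]
//! let mut a = PackedTritVec::new_zero(4);
//! let mut b = PackedTritVec::new_zero(4);
//! a.set(0, Trit::P); a.set(1, Trit::N); a.set(3, Trit::P);
//! b.set(0, Trit::P); b.set(1, Trit::P); b.set(3, Trit::N);
//!
//! assert_eq!(a.get(1), Trit::N);
//! // dot = (+1)(+1) + (-1)(+1) + 0 + (+1)(-1) = -1
//! assert_eq!(a.dot(&b), -1);
//! ```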

use crate::ternary::Trit;
use crate::vsa::SparseVec;

// SIMD feature detection for runtime dispatch
#[cfg(target_arch = "x86_64")]
fn has_avx512_vpopcntdq() -> bool {
    // Check for AVX-512F (foundation) + AVX-512VPOPCNTDQ (native popcount)
    std::arch::is_x86_feature_detected!("avx512f")
        && std::arch::is_x86_feature_detected!("avx512vpopcntdq")
}

#[cfg(target_arch = "x86_64")]
fn has_avx2() -> bool {
    std::arch::is_x86_feature_detected!("avx2")
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PackedTritVec {
    len: usize,
    data: Vec<u64>,
}

impl PackedTritVec {
    const MASK_EVEN_BITS: u64 = 0x5555_5555_5555_5555;

    /// Resize to `len` lanes and zero every word, reusing the allocation when possible.
    #[inline]
    fn ensure_len_and_clear(&mut self, len: usize) {
        self.len = len;
        let words = Self::word_count_for_len(len);
        if self.data.len() != words {
            self.data.resize(words, 0u64);
        }
        self.data.fill(0u64);
    }

    /// Create an all-zero vector with `len` ternary lanes.
    pub fn new_zero(len: usize) -> Self {
        let bits = len.saturating_mul(2);
        let words = bits.div_ceil(64);
        Self {
            len,
            data: vec![0u64; words],
        }
    }

    #[inline]
    fn word_count_for_len(len: usize) -> usize {
        let bits = len.saturating_mul(2);
        bits.div_ceil(64)
    }

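    /// Mask selecting the 2-bit lanes actually in use in the final data word.
    /// For example, `len = 40` leaves `40 % 32 = 8` lanes (16 bits) in the last
    /// word, so the mask is `0xFFFF`; lengths that fill the last word exactly
    /// yield an all-ones mask.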
    #[inline]
    fn last_word_mask(len: usize) -> u64 {
        let lanes_in_last = len % 32;
        if lanes_in_last == 0 {
            !0u64
        } else {
            let used_bits = lanes_in_last * 2;
            if used_bits >= 64 {
                !0u64
            } else {
                (1u64 << used_bits) - 1
            }
        }
    }

    pub fn len(&self) -> usize {
        self.len
    }

    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Map lane index `i` to its (word index, bit offset) within `data`.
    #[inline]
    fn word_bit_index(i: usize) -> (usize, usize) {
        let bit = i * 2;
        (bit / 64, bit % 64)
    }

    /// Return the trit at lane `i` (`Trit::Z` for out-of-range indices).
    pub fn get(&self, i: usize) -> Trit {
        if i >= self.len {
            return Trit::Z;
        }
        let (word, bit) = Self::word_bit_index(i);
        let w = self.data.get(word).copied().unwrap_or(0);
        let v = (w >> bit) & 0b11;
        match v {
            0b01 => Trit::P,
            0b10 => Trit::N,
            _ => Trit::Z,
        }
    }

    /// Set the trit at lane `i`; out-of-range indices are ignored.
    pub fn set(&mut self, i: usize, t: Trit) {
        if i >= self.len {
            return;
        }
        let (word, bit) = Self::word_bit_index(i);
        if let Some(w) = self.data.get_mut(word) {
            *w &= !(0b11u64 << bit);
            let enc = match t {
                Trit::Z => 0b00u64,
                Trit::P => 0b01u64,
                Trit::N => 0b10u64,
            };
            *w |= enc << bit;
        }
    }

    /// Pack a `SparseVec` into a new `PackedTritVec` with `len` lanes.
    pub fn from_sparsevec(vec: &SparseVec, len: usize) -> Self {
        let mut out = Self::new_zero(len);
        out.fill_from_sparsevec(vec, len);
        out
    }

    /// Fill this packed vector from a SparseVec, reusing existing allocation.
    ///
    /// This is a hot-path helper for PackedTritVec operations to avoid repeated allocations.
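    ///
    /// Reuse sketch (marked `ignore`; `sv_a`, `sv_b`, and `dims` are assumed to be in scope):
    ///
    /// ```ignore
    /// let mut packed = PackedTritVec::new_zero(dims);
    /// packed.fill_from_sparsevec(&sv_a, dims); // first fill
    /// packed.fill_from_sparsevec(&sv_b, dims); // later fills reuse the same buffer
    /// ```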
    pub fn fill_from_sparsevec(&mut self, vec: &SparseVec, len: usize) {
        self.ensure_len_and_clear(len);

        // Fast set: output is already zeroed, so we can OR lane encodings.
        // 01 => P (+1) sets even bit; 10 => N (-1) sets odd bit.
        for &idx in &vec.pos {
            if idx < len {
                let bit = idx * 2;
                let word = bit / 64;
                let shift = bit % 64;
                self.data[word] |= 1u64 << shift;
            }
        }

        for &idx in &vec.neg {
            if idx < len {
                let bit = idx * 2;
                let word = bit / 64;
                let shift = bit % 64;
                self.data[word] |= 1u64 << (shift + 1);
            }
        }

        if !self.data.is_empty() {
            let last = self.data.len() - 1;
            self.data[last] &= Self::last_word_mask(self.len);
        }
    }

    pub fn to_sparsevec(&self) -> SparseVec {
        let mut pos: Vec<usize> = Vec::new();
        let mut neg: Vec<usize> = Vec::new();

        // Word-wise extraction: each u64 holds 32 trits (2 bits each).
        for (word_idx, &word_raw) in self.data.iter().enumerate() {
            let mut word = word_raw;
            if word_idx + 1 == self.data.len() {
                word &= Self::last_word_mask(self.len);
            }

            // P lanes have the even bit set; N lanes have the odd bit set.
            // Shift odd bits down to even positions to get per-lane masks.
            let pos_bits = word & Self::MASK_EVEN_BITS;
            let neg_bits = (word >> 1) & Self::MASK_EVEN_BITS;

            // Detect conflicting lanes (0b11 = both P and N set). This can occur
            // when a SparseVec with overlapping pos/neg indices was packed; such
            // lanes are treated as Z (the signs cancel) so the output keeps
            // pos and neg disjoint.
            let conflict_bits = pos_bits & neg_bits;
            let clean_pos = pos_bits & !conflict_bits;
            let clean_neg = neg_bits & !conflict_bits;

            // Extract indices for P lanes.
            let mut m = clean_pos;
            while m != 0 {
                let tz = m.trailing_zeros() as usize;
                let lane = tz / 2;
                let idx = word_idx * 32 + lane;
                if idx < self.len {
                    pos.push(idx);
                }
                m &= m - 1; // clear the lowest set bit
            }

            // Extract indices for N lanes.
            let mut n = clean_neg;
            while n != 0 {
                let tz = n.trailing_zeros() as usize;
                let lane = tz / 2;
                let idx = word_idx * 32 + lane;
                if idx < self.len {
                    neg.push(idx);
                }
                n &= n - 1; // clear the lowest set bit
            }
        }

        SparseVec { pos, neg }
    }

    /// Sparse ternary dot product: sum over i of a_i * b_i.
    ///
    /// Automatically dispatches to the best available SIMD implementation:
    /// - AVX-512 with VPOPCNTDQ (servers, workstations)
    /// - AVX2 (most modern x86_64, including Intel 14th gen)
    /// - Scalar fallback (auto-vectorizes well with LLVM)
    pub fn dot(&self, other: &Self) -> i32 {
        let n = self.len.min(other.len);
        if n == 0 {
            return 0;
        }

        #[cfg(target_arch = "x86_64")]
        {
            let words = Self::word_count_for_len(n)
                .min(self.data.len())
                .min(other.data.len());

            // AVX-512 with native popcount (best, but rare on consumer CPUs)
            if words >= 16 && has_avx512_vpopcntdq() {
                return unsafe { self.dot_avx512(other, n) };
            }

            // AVX2 (available on 14700K and most modern CPUs)
            if words >= 8 && has_avx2() {
                return unsafe { self.dot_avx2(other, n) };
            }
        }

        // Scalar fallback (auto-vectorizes well)
        self.dot_scalar(other, n)
    }

    /// Scalar implementation of dot product (auto-vectorizes well with LLVM).
    #[inline]
    fn dot_scalar(&self, other: &Self, n: usize) -> i32 {
        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len());

        let mut acc: i32 = 0;
        for w in 0..words {
            let mut a = self.data[w];
            let mut b = other.data[w];
            if w + 1 == words {
                let mask = Self::last_word_mask(n);
                a &= mask;
                b &= mask;
            }

            let a_pos = a & Self::MASK_EVEN_BITS;
            let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
            let b_pos = b & Self::MASK_EVEN_BITS;
            let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;

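            // Each lane contributes +1 when the trits share a sign (P·P or N·N)
            // and -1 when they oppose (P·N or N·P); lanes where either operand
            // is Z have no plane bit set and drop out of every intersection.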
            let pp = (a_pos & b_pos).count_ones() as i32;
            let nn = (a_neg & b_neg).count_ones() as i32;
            let pn = (a_pos & b_neg).count_ones() as i32;
            let np = (a_neg & b_pos).count_ones() as i32;

            acc += (pp + nn) - (pn + np);
        }

        acc
    }

    /// AVX-512 accelerated dot product using native VPOPCNTDQ.
    ///
    /// Processes 8 × u64 = 256 trits per iteration.
    ///
    /// # Safety
    /// Caller must ensure AVX-512F and AVX-512VPOPCNTDQ are available.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx512f", enable = "avx512vpopcntdq")]
    unsafe fn dot_avx512(&self, other: &Self, n: usize) -> i32 {
        use std::arch::x86_64::*;

        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len());

        // Broadcast the even-bits mask to all 8 lanes
        let mask_even = _mm512_set1_epi64(Self::MASK_EVEN_BITS as i64);

        // Accumulators for positive and negative contributions
        let mut acc_pos = _mm512_setzero_si512(); // pp + nn
        let mut acc_neg = _mm512_setzero_si512(); // pn + np

        // Process 8 words (256 trits) at a time
        let chunks = words / 8;
        for chunk in 0..chunks {
            let base = chunk * 8;

            // Load 8 × u64 from each vector
            let va = _mm512_loadu_si512(self.data[base..].as_ptr() as *const __m512i);
            let vb = _mm512_loadu_si512(other.data[base..].as_ptr() as *const __m512i);

            // Separate bitplanes: even bits = P plane, odd bits = N plane
            let a_pos = _mm512_and_si512(va, mask_even);
            let a_neg = _mm512_and_si512(_mm512_srli_epi64(va, 1), mask_even);
            let b_pos = _mm512_and_si512(vb, mask_even);
            let b_neg = _mm512_and_si512(_mm512_srli_epi64(vb, 1), mask_even);

            // Compute intersections
            let pp = _mm512_and_si512(a_pos, b_pos); // Both positive
            let nn = _mm512_and_si512(a_neg, b_neg); // Both negative
            let pn = _mm512_and_si512(a_pos, b_neg); // Opposite signs
            let np = _mm512_and_si512(a_neg, b_pos); // Opposite signs

            // Native popcount on 64-bit lanes (AVX-512 VPOPCNTDQ)
            let pp_cnt = _mm512_popcnt_epi64(pp);
            let nn_cnt = _mm512_popcnt_epi64(nn);
            let pn_cnt = _mm512_popcnt_epi64(pn);
            let np_cnt = _mm512_popcnt_epi64(np);

            // Accumulate: positive contributions (matching signs)
            acc_pos = _mm512_add_epi64(acc_pos, pp_cnt);
            acc_pos = _mm512_add_epi64(acc_pos, nn_cnt);

            // Accumulate: negative contributions (opposing signs)
            acc_neg = _mm512_add_epi64(acc_neg, pn_cnt);
            acc_neg = _mm512_add_epi64(acc_neg, np_cnt);
        }

        // Horizontal reduction: sum all 8 lanes
        let pos_sum = _mm512_reduce_add_epi64(acc_pos);
        let neg_sum = _mm512_reduce_add_epi64(acc_neg);
        let mut acc = (pos_sum - neg_sum) as i32;

        // Handle remaining words with scalar
        let remainder_start = chunks * 8;
        for w in remainder_start..words {
            let mut a = self.data[w];
            let mut b = other.data[w];
            if w + 1 == words {
                let mask = Self::last_word_mask(n);
                a &= mask;
                b &= mask;
            }

            let a_pos = a & Self::MASK_EVEN_BITS;
            let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
            let b_pos = b & Self::MASK_EVEN_BITS;
            let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;

            let pp = (a_pos & b_pos).count_ones() as i32;
            let nn = (a_neg & b_neg).count_ones() as i32;
            let pn = (a_pos & b_neg).count_ones() as i32;
            let np = (a_neg & b_pos).count_ones() as i32;

            acc += (pp + nn) - (pn + np);
        }

        acc
    }

    /// AVX2 accelerated dot product with popcount emulation.
    ///
    /// Processes 4 × u64 = 128 trits per iteration.
    /// Emulates popcount with a 4-bit lookup table and PSHUFB, since AVX2 has no
    /// native per-lane popcount instruction.
    ///
    /// # Safety
    /// Caller must ensure AVX2 is available.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn dot_avx2(&self, other: &Self, n: usize) -> i32 {
        use std::arch::x86_64::*;

        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len());

        // Broadcast masks for bitplane separation
        let mask_even = _mm256_set1_epi64x(Self::MASK_EVEN_BITS as i64);

        // Nibble lookup table for popcount (0-15 -> 0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4)
        let popcount_lut = _mm256_setr_epi8(
            0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2,
            3, 3, 4,
        );
        let low_nibble_mask = _mm256_set1_epi8(0x0F);

        // Accumulators (use 64-bit to avoid overflow)
        let mut acc_pos = _mm256_setzero_si256();
        let mut acc_neg = _mm256_setzero_si256();

        // Process 4 words (128 trits) at a time
        let chunks = words / 4;
        for chunk in 0..chunks {
            let base = chunk * 4;

            // Load 4 × u64 from each vector
            let va = _mm256_loadu_si256(self.data[base..].as_ptr() as *const __m256i);
            let vb = _mm256_loadu_si256(other.data[base..].as_ptr() as *const __m256i);

            // Separate bitplanes: even bits = P plane, odd bits = N plane
            let a_pos = _mm256_and_si256(va, mask_even);
            let a_neg = _mm256_and_si256(_mm256_srli_epi64(va, 1), mask_even);
            let b_pos = _mm256_and_si256(vb, mask_even);
            let b_neg = _mm256_and_si256(_mm256_srli_epi64(vb, 1), mask_even);

            // Compute intersections
            let pp = _mm256_and_si256(a_pos, b_pos);
            let nn = _mm256_and_si256(a_neg, b_neg);
            let pn = _mm256_and_si256(a_pos, b_neg);
            let np = _mm256_and_si256(a_neg, b_pos);

            // Popcount via PSHUFB nibble lookup (low and high 4-bit halves per byte)
            let pp_lo = _mm256_shuffle_epi8(popcount_lut, _mm256_and_si256(pp, low_nibble_mask));
            let pp_hi = _mm256_shuffle_epi8(
                popcount_lut,
                _mm256_and_si256(_mm256_srli_epi16(pp, 4), low_nibble_mask),
            );
            let nn_lo = _mm256_shuffle_epi8(popcount_lut, _mm256_and_si256(nn, low_nibble_mask));
            let nn_hi = _mm256_shuffle_epi8(
                popcount_lut,
                _mm256_and_si256(_mm256_srli_epi16(nn, 4), low_nibble_mask),
            );

            let pn_lo = _mm256_shuffle_epi8(popcount_lut, _mm256_and_si256(pn, low_nibble_mask));
            let pn_hi = _mm256_shuffle_epi8(
                popcount_lut,
                _mm256_and_si256(_mm256_srli_epi16(pn, 4), low_nibble_mask),
            );
            let np_lo = _mm256_shuffle_epi8(popcount_lut, _mm256_and_si256(np, low_nibble_mask));
            let np_hi = _mm256_shuffle_epi8(
                popcount_lut,
                _mm256_and_si256(_mm256_srli_epi16(np, 4), low_nibble_mask),
            );

            // Sum byte counts
            let pos_bytes =
                _mm256_add_epi8(_mm256_add_epi8(pp_lo, pp_hi), _mm256_add_epi8(nn_lo, nn_hi));
            let neg_bytes =
                _mm256_add_epi8(_mm256_add_epi8(pn_lo, pn_hi), _mm256_add_epi8(np_lo, np_hi));

            // Horizontal sum of bytes to 64-bit (SAD against zero)
            let pos_sad = _mm256_sad_epu8(pos_bytes, _mm256_setzero_si256());
            let neg_sad = _mm256_sad_epu8(neg_bytes, _mm256_setzero_si256());

            acc_pos = _mm256_add_epi64(acc_pos, pos_sad);
            acc_neg = _mm256_add_epi64(acc_neg, neg_sad);
        }

        // Horizontal reduction: sum all 4 × 64-bit lanes
        let pos_lo = _mm256_castsi256_si128(acc_pos);
        let pos_hi = _mm256_extracti128_si256(acc_pos, 1);
        let pos_sum128 = _mm_add_epi64(pos_lo, pos_hi);

        let neg_lo = _mm256_castsi256_si128(acc_neg);
        let neg_hi = _mm256_extracti128_si256(acc_neg, 1);
        let neg_sum128 = _mm_add_epi64(neg_lo, neg_hi);

        let pos_final = _mm_extract_epi64(pos_sum128, 0) + _mm_extract_epi64(pos_sum128, 1);
        let neg_final = _mm_extract_epi64(neg_sum128, 0) + _mm_extract_epi64(neg_sum128, 1);

        let mut acc = (pos_final - neg_final) as i32;

        // Handle remaining words with scalar
        let remainder_start = chunks * 4;
        for w in remainder_start..words {
            let mut a = self.data[w];
            let mut b = other.data[w];
            if w + 1 == words {
                let mask = Self::last_word_mask(n);
                a &= mask;
                b &= mask;
            }

            let a_pos = a & Self::MASK_EVEN_BITS;
            let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
            let b_pos = b & Self::MASK_EVEN_BITS;
            let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;

            let pp = (a_pos & b_pos).count_ones() as i32;
            let nn = (a_neg & b_neg).count_ones() as i32;
            let pn = (a_pos & b_neg).count_ones() as i32;
            let np = (a_neg & b_pos).count_ones() as i32;

            acc += (pp + nn) - (pn + np);
        }

        acc
    }

    /// Element-wise ternary multiplication (bind primitive).
    ///
    /// Automatically dispatches to the best available SIMD implementation:
    /// - AVX-512F (servers, workstations)
    /// - AVX2 (most modern x86_64)
    /// - Scalar fallback
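    ///
    /// Semantics sketch (marked `ignore`): matching signs give `P`, opposing
    /// signs give `N`, and any lane touching `Z` stays `Z`.
    ///
    /// ```ignore
    /// let mut a = PackedTritVec::new_zero(4);
    /// let mut b = PackedTritVec::new_zero(4);
    /// a.set(0, Trit::P); b.set(0, Trit::P); // +1 * +1 -> +1
    /// a.set(1, Trit::P); b.set(1, Trit::N); // +1 * -1 -> -1
    /// a.set(2, Trit::N);                    // -1 *  0 ->  0
    /// let c = a.bind(&b);
    /// assert_eq!(c.get(0), Trit::P);
    /// assert_eq!(c.get(1), Trit::N);
    /// assert_eq!(c.get(2), Trit::Z);
    /// ```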
    pub fn bind(&self, other: &Self) -> Self {
        let n = self.len.min(other.len);
        if n == 0 {
            return Self::new_zero(0);
        }

        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len());
        let mut out = Self::new_zero(n);

        #[cfg(target_arch = "x86_64")]
        {
            // AVX-512F (processes 8 words = 256 trits per iteration)
            if words >= 8 && std::arch::is_x86_feature_detected!("avx512f") {
                unsafe { self.bind_avx512(other, n, &mut out) };
                return out;
            }

            // AVX2 (processes 4 words = 128 trits per iteration)
            if words >= 4 && has_avx2() {
                unsafe { self.bind_avx2(other, n, &mut out) };
                return out;
            }
        }

        // Scalar fallback
        self.bind_scalar(other, n, &mut out);
        out
    }

    /// Scalar bind implementation
    #[inline]
    fn bind_scalar(&self, other: &Self, n: usize, out: &mut Self) {
        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len())
            .min(out.data.len());

        for w in 0..words {
            let mut a = self.data[w];
            let mut b = other.data[w];
            if w + 1 == words {
                let mask = Self::last_word_mask(n);
                a &= mask;
                b &= mask;
            }

            let a_pos = a & Self::MASK_EVEN_BITS;
            let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
            let b_pos = b & Self::MASK_EVEN_BITS;
            let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;

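            // Balanced-ternary multiplication per lane:
            //   (+1)(+1) = (-1)(-1) = +1  -> `same`
            //   (+1)(-1) = (-1)(+1) = -1  -> `opp`
            //   anything * 0 = 0          -> neither plane bit set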
            let same = (a_pos & b_pos) | (a_neg & b_neg);
            let opp = (a_pos & b_neg) | (a_neg & b_pos);

            out.data[w] = same | (opp << 1);
        }

        // Ensure any unused tail bits stay zero.
        if !out.data.is_empty() {
            let last = out.data.len() - 1;
            out.data[last] &= Self::last_word_mask(out.len);
        }
    }

    /// AVX-512 accelerated bind operation.
    ///
    /// Processes 8 × u64 = 256 trits per iteration using 512-bit registers.
    ///
    /// # Safety
    /// Caller must ensure AVX-512F is available.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx512f")]
    unsafe fn bind_avx512(&self, other: &Self, n: usize, out: &mut Self) {
        use std::arch::x86_64::*;

        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len())
            .min(out.data.len());

        let mask_even = _mm512_set1_epi64(Self::MASK_EVEN_BITS as i64);

        // Process 8 words at a time
        let chunks = words / 8;
        for chunk in 0..chunks {
            let base = chunk * 8;

            let va = _mm512_loadu_si512(self.data[base..].as_ptr() as *const __m512i);
            let vb = _mm512_loadu_si512(other.data[base..].as_ptr() as *const __m512i);

            // Separate bitplanes
            let a_pos = _mm512_and_si512(va, mask_even);
            let a_neg = _mm512_and_si512(_mm512_srli_epi64(va, 1), mask_even);
            let b_pos = _mm512_and_si512(vb, mask_even);
            let b_neg = _mm512_and_si512(_mm512_srli_epi64(vb, 1), mask_even);

            // Bind: same signs -> P, opposite signs -> N
            let same = _mm512_or_si512(
                _mm512_and_si512(a_pos, b_pos),
                _mm512_and_si512(a_neg, b_neg),
            );
            let opp = _mm512_or_si512(
                _mm512_and_si512(a_pos, b_neg),
                _mm512_and_si512(a_neg, b_pos),
            );

            let result = _mm512_or_si512(same, _mm512_slli_epi64(opp, 1));
            _mm512_storeu_si512(out.data[base..].as_mut_ptr() as *mut __m512i, result);
        }

        // Handle remaining words with scalar
        let remainder_start = chunks * 8;
        for w in remainder_start..words {
            let mut a = self.data[w];
            let mut b = other.data[w];
            if w + 1 == words {
                let mask = Self::last_word_mask(n);
                a &= mask;
                b &= mask;
            }

            let a_pos = a & Self::MASK_EVEN_BITS;
            let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
            let b_pos = b & Self::MASK_EVEN_BITS;
            let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;

            let same = (a_pos & b_pos) | (a_neg & b_neg);
            let opp = (a_pos & b_neg) | (a_neg & b_pos);

            out.data[w] = same | (opp << 1);
        }

        if !out.data.is_empty() {
            let last = out.data.len() - 1;
            out.data[last] &= Self::last_word_mask(out.len);
        }
    }

    /// AVX2 accelerated bind operation.
    ///
    /// Processes 4 × u64 = 128 trits per iteration using 256-bit registers.
    ///
    /// # Safety
    /// Caller must ensure AVX2 is available.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn bind_avx2(&self, other: &Self, n: usize, out: &mut Self) {
        use std::arch::x86_64::*;

        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len())
            .min(out.data.len());

        let mask_even = _mm256_set1_epi64x(Self::MASK_EVEN_BITS as i64);

        // Process 4 words at a time
        let chunks = words / 4;
        for chunk in 0..chunks {
            let base = chunk * 4;

            let va = _mm256_loadu_si256(self.data[base..].as_ptr() as *const __m256i);
            let vb = _mm256_loadu_si256(other.data[base..].as_ptr() as *const __m256i);

            // Separate bitplanes
            let a_pos = _mm256_and_si256(va, mask_even);
            let a_neg = _mm256_and_si256(_mm256_srli_epi64(va, 1), mask_even);
            let b_pos = _mm256_and_si256(vb, mask_even);
            let b_neg = _mm256_and_si256(_mm256_srli_epi64(vb, 1), mask_even);

            // Bind: same signs -> P, opposite signs -> N
            let same = _mm256_or_si256(
                _mm256_and_si256(a_pos, b_pos),
                _mm256_and_si256(a_neg, b_neg),
            );
            let opp = _mm256_or_si256(
                _mm256_and_si256(a_pos, b_neg),
                _mm256_and_si256(a_neg, b_pos),
            );

            let result = _mm256_or_si256(same, _mm256_slli_epi64(opp, 1));
            _mm256_storeu_si256(out.data[base..].as_mut_ptr() as *mut __m256i, result);
        }

        // Handle remaining words with scalar
        let remainder_start = chunks * 4;
        for w in remainder_start..words {
            let mut a = self.data[w];
            let mut b = other.data[w];
            if w + 1 == words {
                let mask = Self::last_word_mask(n);
                a &= mask;
                b &= mask;
            }

            let a_pos = a & Self::MASK_EVEN_BITS;
            let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
            let b_pos = b & Self::MASK_EVEN_BITS;
            let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;

            let same = (a_pos & b_pos) | (a_neg & b_neg);
            let opp = (a_pos & b_neg) | (a_neg & b_pos);

            out.data[w] = same | (opp << 1);
        }

        if !out.data.is_empty() {
            let last = out.data.len() - 1;
            out.data[last] &= Self::last_word_mask(out.len);
        }
    }

    /// Element-wise ternary multiplication into an existing output buffer.
    pub fn bind_into(&self, other: &Self, out: &mut Self) {
        let n = self.len.min(other.len);
        out.ensure_len_and_clear(n);
        if n == 0 {
            return;
        }

        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len())
            .min(out.data.len());

        for w in 0..words {
            let mut a = self.data[w];
            let mut b = other.data[w];
            if w + 1 == words {
                let mask = Self::last_word_mask(n);
                a &= mask;
                b &= mask;
            }

            let a_pos = a & Self::MASK_EVEN_BITS;
            let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
            let b_pos = b & Self::MASK_EVEN_BITS;
            let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;

            let same = (a_pos & b_pos) | (a_neg & b_neg);
            let opp = (a_pos & b_neg) | (a_neg & b_pos);

            out.data[w] = same | (opp << 1);
        }

        if !out.data.is_empty() {
            let last = out.data.len() - 1;
            out.data[last] &= Self::last_word_mask(out.len);
        }
    }

    /// Element-wise saturating ternary addition (bundle primitive for two vectors).
    ///
    /// Automatically dispatches to the best available SIMD implementation:
    /// - AVX-512F (servers, workstations)
    /// - AVX2 (most modern x86_64)
    /// - Scalar fallback
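    ///
    /// Semantics sketch (marked `ignore`): equal signs and `x + 0` keep their
    /// value (saturating), while `+1 + -1` cancels to `Z`.
    ///
    /// ```ignore
    /// let mut a = PackedTritVec::new_zero(4);
    /// let mut b = PackedTritVec::new_zero(4);
    /// a.set(0, Trit::P); b.set(0, Trit::N); // +1 + -1 -> 0
    /// a.set(1, Trit::P); b.set(1, Trit::P); // +1 + +1 -> +1 (saturates)
    /// a.set(2, Trit::N);                    // -1 +  0 -> -1
    /// let c = a.bundle(&b);
    /// assert_eq!(c.get(0), Trit::Z);
    /// assert_eq!(c.get(1), Trit::P);
    /// assert_eq!(c.get(2), Trit::N);
    /// ```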
    pub fn bundle(&self, other: &Self) -> Self {
        let n = self.len.min(other.len);
        if n == 0 {
            return Self::new_zero(0);
        }

        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len());
        let mut out = Self::new_zero(n);

        #[cfg(target_arch = "x86_64")]
        {
            // AVX-512F (processes 8 words = 256 trits per iteration)
            if words >= 8 && std::arch::is_x86_feature_detected!("avx512f") {
                unsafe { self.bundle_avx512(other, n, &mut out) };
                return out;
            }

            // AVX2 (processes 4 words = 128 trits per iteration)
            if words >= 4 && has_avx2() {
                unsafe { self.bundle_avx2(other, n, &mut out) };
                return out;
            }
        }

        // Scalar fallback
        self.bundle_scalar(other, n, &mut out);
        out
    }

    /// Scalar bundle implementation
    #[inline]
    fn bundle_scalar(&self, other: &Self, n: usize, out: &mut Self) {
        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len())
            .min(out.data.len());

        for w in 0..words {
            let mut a = self.data[w];
            let mut b = other.data[w];
            if w + 1 == words {
                let mask = Self::last_word_mask(n);
                a &= mask;
                b &= mask;
            }

            let a_pos = a & Self::MASK_EVEN_BITS;
            let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
            let b_pos = b & Self::MASK_EVEN_BITS;
            let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;

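            // Saturating ternary addition per lane:
            //   +1 + -1 -> 0 (cancels), +1 + +1 -> +1, -1 + -1 -> -1, x + 0 -> x.
            // A lane keeps P only if the other operand is not N, and vice versa.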
            let mask = Self::MASK_EVEN_BITS;
            let not_b_neg = (!b_neg) & mask;
            let not_a_neg = (!a_neg) & mask;
            let not_b_pos = (!b_pos) & mask;
            let not_a_pos = (!a_pos) & mask;

            let pos = (a_pos & not_b_neg) | (b_pos & not_a_neg);
            let neg = (a_neg & not_b_pos) | (b_neg & not_a_pos);

            out.data[w] = pos | (neg << 1);
        }

        if !out.data.is_empty() {
            let last = out.data.len() - 1;
            out.data[last] &= Self::last_word_mask(out.len);
        }
    }

    /// AVX-512 accelerated bundle operation.
    ///
    /// Processes 8 × u64 = 256 trits per iteration using 512-bit registers.
    ///
    /// # Safety
    /// Caller must ensure AVX-512F is available.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx512f")]
    unsafe fn bundle_avx512(&self, other: &Self, n: usize, out: &mut Self) {
        use std::arch::x86_64::*;

        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len())
            .min(out.data.len());

        let mask_even = _mm512_set1_epi64(Self::MASK_EVEN_BITS as i64);

        // Process 8 words at a time
        let chunks = words / 8;
        for chunk in 0..chunks {
            let base = chunk * 8;

            let va = _mm512_loadu_si512(self.data[base..].as_ptr() as *const __m512i);
            let vb = _mm512_loadu_si512(other.data[base..].as_ptr() as *const __m512i);

            // Separate bitplanes
            let a_pos = _mm512_and_si512(va, mask_even);
            let a_neg = _mm512_and_si512(_mm512_srli_epi64(va, 1), mask_even);
            let b_pos = _mm512_and_si512(vb, mask_even);
            let b_neg = _mm512_and_si512(_mm512_srli_epi64(vb, 1), mask_even);

            // Bundle: saturating add
            // pos = (a_pos & !b_neg) | (b_pos & !a_neg)
            // neg = (a_neg & !b_pos) | (b_neg & !a_pos)
            let not_b_neg = _mm512_andnot_si512(b_neg, mask_even);
            let not_a_neg = _mm512_andnot_si512(a_neg, mask_even);
            let not_b_pos = _mm512_andnot_si512(b_pos, mask_even);
            let not_a_pos = _mm512_andnot_si512(a_pos, mask_even);

            let pos = _mm512_or_si512(
                _mm512_and_si512(a_pos, not_b_neg),
                _mm512_and_si512(b_pos, not_a_neg),
            );
            let neg = _mm512_or_si512(
                _mm512_and_si512(a_neg, not_b_pos),
                _mm512_and_si512(b_neg, not_a_pos),
            );

            let result = _mm512_or_si512(pos, _mm512_slli_epi64(neg, 1));
            _mm512_storeu_si512(out.data[base..].as_mut_ptr() as *mut __m512i, result);
        }

        // Handle remaining words with scalar
        let remainder_start = chunks * 8;
        for w in remainder_start..words {
            let mut a = self.data[w];
            let mut b = other.data[w];
            if w + 1 == words {
                let mask = Self::last_word_mask(n);
                a &= mask;
                b &= mask;
            }

            let a_pos = a & Self::MASK_EVEN_BITS;
            let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
            let b_pos = b & Self::MASK_EVEN_BITS;
            let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;

            let mask = Self::MASK_EVEN_BITS;
            let not_b_neg = (!b_neg) & mask;
            let not_a_neg = (!a_neg) & mask;
            let not_b_pos = (!b_pos) & mask;
            let not_a_pos = (!a_pos) & mask;

            let pos = (a_pos & not_b_neg) | (b_pos & not_a_neg);
            let neg = (a_neg & not_b_pos) | (b_neg & not_a_pos);

            out.data[w] = pos | (neg << 1);
        }

        if !out.data.is_empty() {
            let last = out.data.len() - 1;
            out.data[last] &= Self::last_word_mask(out.len);
        }
    }

    /// AVX2 accelerated bundle operation.
    ///
    /// Processes 4 × u64 = 128 trits per iteration using 256-bit registers.
    ///
    /// # Safety
    /// Caller must ensure AVX2 is available.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn bundle_avx2(&self, other: &Self, n: usize, out: &mut Self) {
        use std::arch::x86_64::*;

        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len())
            .min(out.data.len());

        let mask_even = _mm256_set1_epi64x(Self::MASK_EVEN_BITS as i64);

        // Process 4 words at a time
        let chunks = words / 4;
        for chunk in 0..chunks {
            let base = chunk * 4;

            let va = _mm256_loadu_si256(self.data[base..].as_ptr() as *const __m256i);
            let vb = _mm256_loadu_si256(other.data[base..].as_ptr() as *const __m256i);

            // Separate bitplanes
            let a_pos = _mm256_and_si256(va, mask_even);
            let a_neg = _mm256_and_si256(_mm256_srli_epi64(va, 1), mask_even);
            let b_pos = _mm256_and_si256(vb, mask_even);
            let b_neg = _mm256_and_si256(_mm256_srli_epi64(vb, 1), mask_even);

            // Bundle: saturating add
            let not_b_neg = _mm256_andnot_si256(b_neg, mask_even);
            let not_a_neg = _mm256_andnot_si256(a_neg, mask_even);
            let not_b_pos = _mm256_andnot_si256(b_pos, mask_even);
            let not_a_pos = _mm256_andnot_si256(a_pos, mask_even);

            let pos = _mm256_or_si256(
                _mm256_and_si256(a_pos, not_b_neg),
                _mm256_and_si256(b_pos, not_a_neg),
            );
            let neg = _mm256_or_si256(
                _mm256_and_si256(a_neg, not_b_pos),
                _mm256_and_si256(b_neg, not_a_pos),
            );

            let result = _mm256_or_si256(pos, _mm256_slli_epi64(neg, 1));
            _mm256_storeu_si256(out.data[base..].as_mut_ptr() as *mut __m256i, result);
        }

        // Handle remaining words with scalar
        let remainder_start = chunks * 4;
        for w in remainder_start..words {
            let mut a = self.data[w];
            let mut b = other.data[w];
            if w + 1 == words {
                let mask = Self::last_word_mask(n);
                a &= mask;
                b &= mask;
            }

            let a_pos = a & Self::MASK_EVEN_BITS;
            let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
            let b_pos = b & Self::MASK_EVEN_BITS;
            let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;

            let mask = Self::MASK_EVEN_BITS;
            let not_b_neg = (!b_neg) & mask;
            let not_a_neg = (!a_neg) & mask;
            let not_b_pos = (!b_pos) & mask;
            let not_a_pos = (!a_pos) & mask;

            let pos = (a_pos & not_b_neg) | (b_pos & not_a_neg);
            let neg = (a_neg & not_b_pos) | (b_neg & not_a_pos);

            out.data[w] = pos | (neg << 1);
        }

        if !out.data.is_empty() {
            let last = out.data.len() - 1;
            out.data[last] &= Self::last_word_mask(out.len);
        }
    }

    /// Element-wise saturating ternary addition into an existing output buffer.
    pub fn bundle_into(&self, other: &Self, out: &mut Self) {
        let n = self.len.min(other.len);
        out.ensure_len_and_clear(n);
        if n == 0 {
            return;
        }

        let words = Self::word_count_for_len(n)
            .min(self.data.len())
            .min(other.data.len())
            .min(out.data.len());

        for w in 0..words {
            let mut a = self.data[w];
            let mut b = other.data[w];
            if w + 1 == words {
                let mask = Self::last_word_mask(n);
                a &= mask;
                b &= mask;
            }

            let a_pos = a & Self::MASK_EVEN_BITS;
            let a_neg = (a >> 1) & Self::MASK_EVEN_BITS;
            let b_pos = b & Self::MASK_EVEN_BITS;
            let b_neg = (b >> 1) & Self::MASK_EVEN_BITS;

            let mask = Self::MASK_EVEN_BITS;
            let not_b_neg = (!b_neg) & mask;
            let not_a_neg = (!a_neg) & mask;
            let not_b_pos = (!b_pos) & mask;
            let not_a_pos = (!a_pos) & mask;

            let pos = (a_pos & not_b_neg) | (b_pos & not_a_neg);
            let neg = (a_neg & not_b_pos) | (b_neg & not_a_pos);

            out.data[w] = pos | (neg << 1);
        }

        if !out.data.is_empty() {
            let last = out.data.len() - 1;
            out.data[last] &= Self::last_word_mask(out.len);
        }
    }
}