Skip to main content

trit_vsa/
sparse.rs

1//! Sparse ternary vector storage using COO format.
2//!
3//! This module provides `SparseVec`, an efficient representation for highly
4//! sparse ternary vectors. It stores only non-zero indices and their signs.
5
6use serde::{Deserialize, Serialize};
7use std::fmt;
8
9use crate::error::{Result, TernaryError};
10use crate::packed::PackedTritVec;
11use crate::trit::Trit;
12
13/// A sparse ternary vector using COO (Coordinate) format.
14///
15/// Only non-zero values are stored, making this efficient for vectors where
16/// most elements are zero (high sparsity).
17///
18/// # Storage
19///
20/// Non-zero indices are stored separately for positive and negative values:
21/// - `positive_indices`: indices where value is +1
22/// - `negative_indices`: indices where value is -1
23///
24/// # When to Use
25///
26/// Use `SparseVec` when sparsity > 90% for memory efficiency.
27/// Use `PackedTritVec` for denser vectors or when operations like dot product
28/// need consistent O(n) time regardless of sparsity.
29///
30/// # Examples
31///
32/// ```
33/// use trit_vsa::{SparseVec, Trit};
34///
35/// let mut vec = SparseVec::new(1000);
36/// vec.set(10, Trit::P);
37/// vec.set(500, Trit::N);
38///
39/// assert_eq!(vec.get(10), Trit::P);
40/// assert_eq!(vec.get(500), Trit::N);
41/// assert_eq!(vec.get(0), Trit::Z);
42/// assert_eq!(vec.count_nonzero(), 2);
43/// ```
44#[derive(Clone, Serialize, Deserialize)]
45pub struct SparseVec {
46    /// Indices where value is +1 (sorted).
47    positive_indices: Vec<usize>,
48    /// Indices where value is -1 (sorted).
49    negative_indices: Vec<usize>,
50    /// Logical dimension count.
51    num_dims: usize,
52}
53
54impl SparseVec {
55    /// Create a new sparse vector with given dimension count.
56    ///
57    /// All values are initialized to zero (no storage needed).
58    #[must_use]
59    pub fn new(num_dims: usize) -> Self {
60        Self {
61            positive_indices: Vec::new(),
62            negative_indices: Vec::new(),
63            num_dims,
64        }
65    }
66
67    /// Create from separate index lists.
68    ///
69    /// # Arguments
70    ///
71    /// * `positive_indices` - Indices where value is +1
72    /// * `negative_indices` - Indices where value is -1
73    /// * `num_dims` - Logical dimension count
74    ///
75    /// # Errors
76    ///
77    /// Returns error if any index is out of bounds or if there are duplicates
78    /// across positive and negative lists.
79    pub fn from_indices(
80        mut positive_indices: Vec<usize>,
81        mut negative_indices: Vec<usize>,
82        num_dims: usize,
83    ) -> Result<Self> {
84        // Validate and sort
85        positive_indices.sort_unstable();
86        negative_indices.sort_unstable();
87
88        // Check bounds
89        if let Some(&max) = positive_indices.last() {
90            if max >= num_dims {
91                return Err(TernaryError::IndexOutOfBounds {
92                    index: max,
93                    size: num_dims,
94                });
95            }
96        }
97        if let Some(&max) = negative_indices.last() {
98            if max >= num_dims {
99                return Err(TernaryError::IndexOutOfBounds {
100                    index: max,
101                    size: num_dims,
102                });
103            }
104        }
105
106        // Check for overlap (same index can't be both positive and negative)
107        let mut pi = 0;
108        let mut ni = 0;
109        while pi < positive_indices.len() && ni < negative_indices.len() {
110            match positive_indices[pi].cmp(&negative_indices[ni]) {
111                std::cmp::Ordering::Equal => {
112                    return Err(TernaryError::InvalidValue(positive_indices[pi] as i32));
113                }
114                std::cmp::Ordering::Less => pi += 1,
115                std::cmp::Ordering::Greater => ni += 1,
116            }
117        }
118
119        Ok(Self {
120            positive_indices,
121            negative_indices,
122            num_dims,
123        })
124    }
125
126    /// Create from a slice of trits.
127    #[must_use]
128    pub fn from_trits(trits: &[Trit]) -> Self {
129        let mut positive_indices = Vec::new();
130        let mut negative_indices = Vec::new();
131
132        for (i, &trit) in trits.iter().enumerate() {
133            match trit {
134                Trit::P => positive_indices.push(i),
135                Trit::N => negative_indices.push(i),
136                Trit::Z => {}
137            }
138        }
139
140        Self {
141            positive_indices,
142            negative_indices,
143            num_dims: trits.len(),
144        }
145    }
146
147    /// Create from a [`PackedTritVec`].
148    #[must_use]
149    pub fn from_packed(packed: &PackedTritVec) -> Self {
150        let mut positive_indices = Vec::new();
151        let mut negative_indices = Vec::new();
152
153        for i in 0..packed.len() {
154            match packed.get(i) {
155                Trit::P => positive_indices.push(i),
156                Trit::N => negative_indices.push(i),
157                Trit::Z => {}
158            }
159        }
160
161        Self {
162            positive_indices,
163            negative_indices,
164            num_dims: packed.len(),
165        }
166    }
167
168    /// Get the number of logical dimensions.
169    #[must_use]
170    pub const fn len(&self) -> usize {
171        self.num_dims
172    }
173
174    /// Check if the vector is empty.
175    #[must_use]
176    pub const fn is_empty(&self) -> bool {
177        self.num_dims == 0
178    }
179
180    /// Set a dimension to a trit value.
181    ///
182    /// # Panics
183    ///
184    /// Panics if `dim >= len()`.
185    pub fn set(&mut self, dim: usize, value: Trit) {
186        assert!(dim < self.num_dims, "dimension out of bounds");
187
188        // Remove from current lists
189        self.positive_indices.retain(|&i| i != dim);
190        self.negative_indices.retain(|&i| i != dim);
191
192        // Add to appropriate list
193        match value {
194            Trit::P => {
195                let pos = self.positive_indices.partition_point(|&x| x < dim);
196                self.positive_indices.insert(pos, dim);
197            }
198            Trit::N => {
199                let pos = self.negative_indices.partition_point(|&x| x < dim);
200                self.negative_indices.insert(pos, dim);
201            }
202            Trit::Z => {} // Already removed
203        }
204    }
205
206    /// Get the trit value at a dimension.
207    ///
208    /// # Panics
209    ///
210    /// Panics if `dim >= len()`.
211    #[must_use]
212    pub fn get(&self, dim: usize) -> Trit {
213        assert!(dim < self.num_dims, "dimension out of bounds");
214
215        if self.positive_indices.binary_search(&dim).is_ok() {
216            Trit::P
217        } else if self.negative_indices.binary_search(&dim).is_ok() {
218            Trit::N
219        } else {
220            Trit::Z
221        }
222    }
223
224    /// Get the number of dimensions.
225    #[must_use]
226    pub fn num_dims(&self) -> usize {
227        self.num_dims
228    }
229
230    /// Count non-zero elements.
231    #[must_use]
232    pub fn count_nonzero(&self) -> usize {
233        self.positive_indices.len() + self.negative_indices.len()
234    }
235
236    /// Count positive (+1) elements.
237    #[must_use]
238    pub fn count_positive(&self) -> usize {
239        self.positive_indices.len()
240    }
241
242    /// Count negative (-1) elements.
243    #[must_use]
244    pub fn count_negative(&self) -> usize {
245        self.negative_indices.len()
246    }
247
248    /// Calculate sparsity (fraction of zeros).
249    #[must_use]
250    #[allow(clippy::cast_precision_loss)]
251    pub fn sparsity(&self) -> f32 {
252        if self.num_dims == 0 {
253            return 1.0;
254        }
255        1.0 - (self.count_nonzero() as f32 / self.num_dims as f32)
256    }
257
258    /// Compute dot product with another sparse vector.
259    ///
260    /// This is O(k1 + k2) where k1 and k2 are the number of non-zero elements.
261    ///
262    /// # Panics
263    ///
264    /// Panics if vectors have different dimensions.
265    #[must_use]
266    pub fn dot(&self, other: &SparseVec) -> i32 {
267        assert_eq!(
268            self.num_dims, other.num_dims,
269            "vectors must have same dimensions"
270        );
271
272        let mut result: i32 = 0;
273
274        // Count intersections between same-sign indices
275        result += Self::count_intersection(&self.positive_indices, &other.positive_indices) as i32;
276        result += Self::count_intersection(&self.negative_indices, &other.negative_indices) as i32;
277
278        // Subtract intersections between opposite-sign indices
279        result -= Self::count_intersection(&self.positive_indices, &other.negative_indices) as i32;
280        result -= Self::count_intersection(&self.negative_indices, &other.positive_indices) as i32;
281
282        result
283    }
284
285    /// Compute dot product with a packed vector.
286    ///
287    /// Efficient when this sparse vector has few non-zeros.
288    ///
289    /// # Panics
290    ///
291    /// Panics if vectors have different dimensions.
292    #[must_use]
293    pub fn dot_packed(&self, other: &PackedTritVec) -> i32 {
294        assert_eq!(
295            self.num_dims,
296            other.len(),
297            "vectors must have same dimensions"
298        );
299
300        let mut result: i32 = 0;
301
302        // Sum contributions from positive indices
303        for &idx in &self.positive_indices {
304            result += other.get(idx).value() as i32;
305        }
306
307        // Sum contributions from negative indices (note: we add negative of other's value)
308        for &idx in &self.negative_indices {
309            result -= other.get(idx).value() as i32;
310        }
311
312        result
313    }
314
315    /// Compute the sum of all elements.
316    #[must_use]
317    pub fn sum(&self) -> i32 {
318        self.positive_indices.len() as i32 - self.negative_indices.len() as i32
319    }
320
321    /// Return a negated copy.
322    #[must_use]
323    pub fn negated(&self) -> Self {
324        Self {
325            positive_indices: self.negative_indices.clone(),
326            negative_indices: self.positive_indices.clone(),
327            num_dims: self.num_dims,
328        }
329    }
330
331    /// Get reference to positive indices.
332    #[must_use]
333    pub fn positive_indices(&self) -> &[usize] {
334        &self.positive_indices
335    }
336
337    /// Get reference to negative indices.
338    #[must_use]
339    pub fn negative_indices(&self) -> &[usize] {
340        &self.negative_indices
341    }
342
343    /// Convert to a [`PackedTritVec`].
344    #[must_use]
345    pub fn to_packed(&self) -> PackedTritVec {
346        let mut packed = PackedTritVec::new(self.num_dims);
347        for &idx in &self.positive_indices {
348            packed.set(idx, Trit::P);
349        }
350        for &idx in &self.negative_indices {
351            packed.set(idx, Trit::N);
352        }
353        packed
354    }
355
356    /// Convert to a vector of trits.
357    #[must_use]
358    pub fn to_trits(&self) -> Vec<Trit> {
359        let mut result = vec![Trit::Z; self.num_dims];
360        for &idx in &self.positive_indices {
361            result[idx] = Trit::P;
362        }
363        for &idx in &self.negative_indices {
364            result[idx] = Trit::N;
365        }
366        result
367    }
368
369    /// Memory size in bytes (approximate).
370    #[must_use]
371    pub fn memory_bytes(&self) -> usize {
372        // Vec overhead + index storage
373        std::mem::size_of::<Self>()
374            + self.positive_indices.capacity() * std::mem::size_of::<usize>()
375            + self.negative_indices.capacity() * std::mem::size_of::<usize>()
376    }
377
378    // Internal: count intersection of two sorted lists
379    fn count_intersection(a: &[usize], b: &[usize]) -> usize {
380        let mut count = 0;
381        let mut ai = 0;
382        let mut bi = 0;
383
384        while ai < a.len() && bi < b.len() {
385            match a[ai].cmp(&b[bi]) {
386                std::cmp::Ordering::Equal => {
387                    count += 1;
388                    ai += 1;
389                    bi += 1;
390                }
391                std::cmp::Ordering::Less => ai += 1,
392                std::cmp::Ordering::Greater => bi += 1,
393            }
394        }
395
396        count
397    }
398}
399
400impl fmt::Debug for SparseVec {
401    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
402        write!(
403            f,
404            "SparseVec(dims={}, pos={}, neg={}, sparsity={:.2}%)",
405            self.num_dims,
406            self.positive_indices.len(),
407            self.negative_indices.len(),
408            self.sparsity() * 100.0
409        )
410    }
411}
412
413impl PartialEq for SparseVec {
414    fn eq(&self, other: &Self) -> bool {
415        self.num_dims == other.num_dims
416            && self.positive_indices == other.positive_indices
417            && self.negative_indices == other.negative_indices
418    }
419}
420
421impl Eq for SparseVec {}
422
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_new() {
        let vec = SparseVec::new(1000);
        assert_eq!(vec.len(), 1000);
        assert_eq!(vec.count_nonzero(), 0);
        assert!((vec.sparsity() - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_sparse_empty() {
        let vec = SparseVec::new(0);
        assert!(vec.is_empty());
        assert_eq!(vec.num_dims(), 0);
        // A zero-length vector is defined to be fully sparse.
        assert!((vec.sparsity() - 1.0).abs() < f32::EPSILON);
        assert!(vec.to_trits().is_empty());
        assert_eq!(vec.sum(), 0);
    }

    #[test]
    fn test_sparse_set_get() {
        let mut vec = SparseVec::new(100);

        vec.set(10, Trit::P);
        vec.set(20, Trit::N);
        vec.set(50, Trit::P);

        assert_eq!(vec.get(10), Trit::P);
        assert_eq!(vec.get(20), Trit::N);
        assert_eq!(vec.get(50), Trit::P);
        assert_eq!(vec.get(0), Trit::Z);
        assert_eq!(vec.get(99), Trit::Z);
    }

    #[test]
    fn test_sparse_overwrite() {
        let mut vec = SparseVec::new(10);

        vec.set(0, Trit::P);
        assert_eq!(vec.get(0), Trit::P);
        assert_eq!(vec.count_nonzero(), 1);

        vec.set(0, Trit::N);
        assert_eq!(vec.get(0), Trit::N);
        assert_eq!(vec.count_nonzero(), 1);

        vec.set(0, Trit::Z);
        assert_eq!(vec.get(0), Trit::Z);
        assert_eq!(vec.count_nonzero(), 0);
    }

    #[test]
    #[should_panic(expected = "dimension out of bounds")]
    fn test_sparse_get_out_of_bounds_panics() {
        let vec = SparseVec::new(10);
        let _ = vec.get(10);
    }

    #[test]
    #[should_panic(expected = "dimension out of bounds")]
    fn test_sparse_set_out_of_bounds_panics() {
        let mut vec = SparseVec::new(10);
        vec.set(10, Trit::P);
    }

    #[test]
    fn test_sparse_dot() {
        let mut a = SparseVec::new(100);
        let mut b = SparseVec::new(100);

        // a = [+1 at 0, -1 at 1, +1 at 10]
        a.set(0, Trit::P);
        a.set(1, Trit::N);
        a.set(10, Trit::P);

        // b = [+1 at 0, +1 at 1, -1 at 20]
        b.set(0, Trit::P);
        b.set(1, Trit::P);
        b.set(20, Trit::N);

        // dot = 1*1 + (-1)*1 + 1*0 + 0*(-1) = 1 - 1 = 0
        assert_eq!(a.dot(&b), 0);

        // Modify b[1] to -1
        b.set(1, Trit::N);
        // dot = 1*1 + (-1)*(-1) + 1*0 + 0*(-1) = 1 + 1 = 2
        assert_eq!(a.dot(&b), 2);
    }

    #[test]
    fn test_sparse_dot_agrees_with_dot_packed() {
        let a = SparseVec::from_trits(&[Trit::P, Trit::N, Trit::Z, Trit::P, Trit::N]);
        let b = SparseVec::from_trits(&[Trit::N, Trit::N, Trit::P, Trit::P, Trit::Z]);

        // a·b = -1 + 1 + 0 + 1 + 0 = 1
        assert_eq!(a.dot(&b), 1);
        assert_eq!(a.dot_packed(&b.to_packed()), 1);

        // Self dot product counts every non-zero once.
        assert_eq!(a.dot(&a), a.count_nonzero() as i32);
    }

    #[test]
    fn test_sparse_dot_packed() {
        let mut sparse = SparseVec::new(64);
        let mut packed = PackedTritVec::new(64);

        sparse.set(0, Trit::P);
        sparse.set(1, Trit::N);

        packed.set(0, Trit::P);
        packed.set(1, Trit::P);
        packed.set(2, Trit::N);

        // dot = 1*1 + (-1)*1 = 0
        assert_eq!(sparse.dot_packed(&packed), 0);

        packed.set(1, Trit::N);
        // dot = 1*1 + (-1)*(-1) = 2
        assert_eq!(sparse.dot_packed(&packed), 2);
    }

    #[test]
    fn test_sparse_from_trits() {
        let trits = [Trit::P, Trit::N, Trit::Z, Trit::P, Trit::Z];
        let vec = SparseVec::from_trits(&trits);

        assert_eq!(vec.len(), 5);
        assert_eq!(vec.count_positive(), 2);
        assert_eq!(vec.count_negative(), 1);

        assert_eq!(vec.to_trits(), trits);
    }

    #[test]
    fn test_sparse_to_packed_roundtrip() {
        let mut sparse = SparseVec::new(100);
        sparse.set(0, Trit::P);
        sparse.set(50, Trit::N);
        sparse.set(99, Trit::P);

        let packed = sparse.to_packed();
        assert_eq!(packed.len(), 100);
        let back = SparseVec::from_packed(&packed);

        assert_eq!(sparse, back);
    }

    #[test]
    fn test_sparse_negated() {
        let mut vec = SparseVec::new(10);
        vec.set(0, Trit::P);
        vec.set(1, Trit::N);
        vec.set(2, Trit::P);

        let neg = vec.negated();

        assert_eq!(neg.get(0), Trit::N);
        assert_eq!(neg.get(1), Trit::P);
        assert_eq!(neg.sum(), -vec.sum());
        // Negation is an involution.
        assert_eq!(neg.negated(), vec);
    }

    #[test]
    fn test_sparse_from_indices() {
        let pos = vec![0, 10, 50];
        let neg = vec![5, 20];
        let vec = SparseVec::from_indices(pos, neg, 100).unwrap();

        assert_eq!(vec.get(0), Trit::P);
        assert_eq!(vec.get(10), Trit::P);
        assert_eq!(vec.get(50), Trit::P);
        assert_eq!(vec.get(5), Trit::N);
        assert_eq!(vec.get(20), Trit::N);
        assert_eq!(vec.get(1), Trit::Z);
    }

    #[test]
    fn test_sparse_from_indices_sorts_input() {
        let vec = SparseVec::from_indices(vec![50, 0, 10], vec![20, 5], 100).unwrap();

        assert_eq!(vec.positive_indices(), &[0, 10, 50]);
        assert_eq!(vec.negative_indices(), &[5, 20]);
    }

    #[test]
    fn test_sparse_from_indices_overlap_error() {
        let pos = vec![0, 10];
        let neg = vec![10, 20]; // 10 is in both - invalid
        let result = SparseVec::from_indices(pos, neg, 100);
        assert!(matches!(result, Err(TernaryError::InvalidValue(10))));
    }

    #[test]
    fn test_sparse_from_indices_bounds_error() {
        let pos = vec![100]; // Out of bounds for dim=100
        let neg = vec![];
        let result = SparseVec::from_indices(pos, neg, 100);
        assert!(matches!(
            result,
            Err(TernaryError::IndexOutOfBounds { index: 100, size: 100 })
        ));
    }

    #[test]
    fn test_sparse_from_indices_negative_bounds_error() {
        let result = SparseVec::from_indices(vec![1], vec![3, 150], 100);
        assert!(matches!(
            result,
            Err(TernaryError::IndexOutOfBounds { index: 150, size: 100 })
        ));
    }

    #[test]
    fn test_sparse_sum() {
        let mut vec = SparseVec::new(100);
        vec.set(0, Trit::P);
        vec.set(1, Trit::P);
        vec.set(2, Trit::N);

        assert_eq!(vec.sum(), 1); // 1 + 1 - 1 = 1
    }
}