1use serde::{Deserialize, Serialize};
7use std::fmt;
8
9use crate::error::{Result, TernaryError};
10use crate::packed::PackedTritVec;
11use crate::trit::Trit;
12
13#[derive(Clone, Serialize, Deserialize)]
45pub struct SparseVec {
46 positive_indices: Vec<usize>,
48 negative_indices: Vec<usize>,
50 num_dims: usize,
52}
53
54impl SparseVec {
55 #[must_use]
59 pub fn new(num_dims: usize) -> Self {
60 Self {
61 positive_indices: Vec::new(),
62 negative_indices: Vec::new(),
63 num_dims,
64 }
65 }
66
67 pub fn from_indices(
80 mut positive_indices: Vec<usize>,
81 mut negative_indices: Vec<usize>,
82 num_dims: usize,
83 ) -> Result<Self> {
84 positive_indices.sort_unstable();
86 negative_indices.sort_unstable();
87
88 if let Some(&max) = positive_indices.last() {
90 if max >= num_dims {
91 return Err(TernaryError::IndexOutOfBounds {
92 index: max,
93 size: num_dims,
94 });
95 }
96 }
97 if let Some(&max) = negative_indices.last() {
98 if max >= num_dims {
99 return Err(TernaryError::IndexOutOfBounds {
100 index: max,
101 size: num_dims,
102 });
103 }
104 }
105
106 let mut pi = 0;
108 let mut ni = 0;
109 while pi < positive_indices.len() && ni < negative_indices.len() {
110 match positive_indices[pi].cmp(&negative_indices[ni]) {
111 std::cmp::Ordering::Equal => {
112 return Err(TernaryError::InvalidValue(positive_indices[pi] as i32));
113 }
114 std::cmp::Ordering::Less => pi += 1,
115 std::cmp::Ordering::Greater => ni += 1,
116 }
117 }
118
119 Ok(Self {
120 positive_indices,
121 negative_indices,
122 num_dims,
123 })
124 }
125
126 #[must_use]
128 pub fn from_trits(trits: &[Trit]) -> Self {
129 let mut positive_indices = Vec::new();
130 let mut negative_indices = Vec::new();
131
132 for (i, &trit) in trits.iter().enumerate() {
133 match trit {
134 Trit::P => positive_indices.push(i),
135 Trit::N => negative_indices.push(i),
136 Trit::Z => {}
137 }
138 }
139
140 Self {
141 positive_indices,
142 negative_indices,
143 num_dims: trits.len(),
144 }
145 }
146
147 #[must_use]
149 pub fn from_packed(packed: &PackedTritVec) -> Self {
150 let mut positive_indices = Vec::new();
151 let mut negative_indices = Vec::new();
152
153 for i in 0..packed.len() {
154 match packed.get(i) {
155 Trit::P => positive_indices.push(i),
156 Trit::N => negative_indices.push(i),
157 Trit::Z => {}
158 }
159 }
160
161 Self {
162 positive_indices,
163 negative_indices,
164 num_dims: packed.len(),
165 }
166 }
167
168 #[must_use]
170 pub const fn len(&self) -> usize {
171 self.num_dims
172 }
173
174 #[must_use]
176 pub const fn is_empty(&self) -> bool {
177 self.num_dims == 0
178 }
179
180 pub fn set(&mut self, dim: usize, value: Trit) {
186 assert!(dim < self.num_dims, "dimension out of bounds");
187
188 self.positive_indices.retain(|&i| i != dim);
190 self.negative_indices.retain(|&i| i != dim);
191
192 match value {
194 Trit::P => {
195 let pos = self.positive_indices.partition_point(|&x| x < dim);
196 self.positive_indices.insert(pos, dim);
197 }
198 Trit::N => {
199 let pos = self.negative_indices.partition_point(|&x| x < dim);
200 self.negative_indices.insert(pos, dim);
201 }
202 Trit::Z => {} }
204 }
205
206 #[must_use]
212 pub fn get(&self, dim: usize) -> Trit {
213 assert!(dim < self.num_dims, "dimension out of bounds");
214
215 if self.positive_indices.binary_search(&dim).is_ok() {
216 Trit::P
217 } else if self.negative_indices.binary_search(&dim).is_ok() {
218 Trit::N
219 } else {
220 Trit::Z
221 }
222 }
223
224 #[must_use]
226 pub fn count_nonzero(&self) -> usize {
227 self.positive_indices.len() + self.negative_indices.len()
228 }
229
230 #[must_use]
232 pub fn count_positive(&self) -> usize {
233 self.positive_indices.len()
234 }
235
236 #[must_use]
238 pub fn count_negative(&self) -> usize {
239 self.negative_indices.len()
240 }
241
242 #[must_use]
244 #[allow(clippy::cast_precision_loss)]
245 pub fn sparsity(&self) -> f32 {
246 if self.num_dims == 0 {
247 return 1.0;
248 }
249 1.0 - (self.count_nonzero() as f32 / self.num_dims as f32)
250 }
251
252 #[must_use]
260 pub fn dot(&self, other: &SparseVec) -> i32 {
261 assert_eq!(
262 self.num_dims, other.num_dims,
263 "vectors must have same dimensions"
264 );
265
266 let mut result: i32 = 0;
267
268 result += Self::count_intersection(&self.positive_indices, &other.positive_indices) as i32;
270 result += Self::count_intersection(&self.negative_indices, &other.negative_indices) as i32;
271
272 result -= Self::count_intersection(&self.positive_indices, &other.negative_indices) as i32;
274 result -= Self::count_intersection(&self.negative_indices, &other.positive_indices) as i32;
275
276 result
277 }
278
279 #[must_use]
287 pub fn dot_packed(&self, other: &PackedTritVec) -> i32 {
288 assert_eq!(
289 self.num_dims,
290 other.len(),
291 "vectors must have same dimensions"
292 );
293
294 let mut result: i32 = 0;
295
296 for &idx in &self.positive_indices {
298 result += other.get(idx).value() as i32;
299 }
300
301 for &idx in &self.negative_indices {
303 result -= other.get(idx).value() as i32;
304 }
305
306 result
307 }
308
309 #[must_use]
311 pub fn sum(&self) -> i32 {
312 self.positive_indices.len() as i32 - self.negative_indices.len() as i32
313 }
314
315 #[must_use]
317 pub fn negated(&self) -> Self {
318 Self {
319 positive_indices: self.negative_indices.clone(),
320 negative_indices: self.positive_indices.clone(),
321 num_dims: self.num_dims,
322 }
323 }
324
325 #[must_use]
327 pub fn positive_indices(&self) -> &[usize] {
328 &self.positive_indices
329 }
330
331 #[must_use]
333 pub fn negative_indices(&self) -> &[usize] {
334 &self.negative_indices
335 }
336
337 #[must_use]
339 pub fn to_packed(&self) -> PackedTritVec {
340 let mut packed = PackedTritVec::new(self.num_dims);
341 for &idx in &self.positive_indices {
342 packed.set(idx, Trit::P);
343 }
344 for &idx in &self.negative_indices {
345 packed.set(idx, Trit::N);
346 }
347 packed
348 }
349
350 #[must_use]
352 pub fn to_trits(&self) -> Vec<Trit> {
353 let mut result = vec![Trit::Z; self.num_dims];
354 for &idx in &self.positive_indices {
355 result[idx] = Trit::P;
356 }
357 for &idx in &self.negative_indices {
358 result[idx] = Trit::N;
359 }
360 result
361 }
362
363 #[must_use]
365 pub fn memory_bytes(&self) -> usize {
366 std::mem::size_of::<Self>()
368 + self.positive_indices.capacity() * std::mem::size_of::<usize>()
369 + self.negative_indices.capacity() * std::mem::size_of::<usize>()
370 }
371
372 fn count_intersection(a: &[usize], b: &[usize]) -> usize {
374 let mut count = 0;
375 let mut ai = 0;
376 let mut bi = 0;
377
378 while ai < a.len() && bi < b.len() {
379 match a[ai].cmp(&b[bi]) {
380 std::cmp::Ordering::Equal => {
381 count += 1;
382 ai += 1;
383 bi += 1;
384 }
385 std::cmp::Ordering::Less => ai += 1,
386 std::cmp::Ordering::Greater => bi += 1,
387 }
388 }
389
390 count
391 }
392}
393
394impl fmt::Debug for SparseVec {
395 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
396 write!(
397 f,
398 "SparseVec(dims={}, pos={}, neg={}, sparsity={:.2}%)",
399 self.num_dims,
400 self.positive_indices.len(),
401 self.negative_indices.len(),
402 self.sparsity() * 100.0
403 )
404 }
405}
406
407impl PartialEq for SparseVec {
408 fn eq(&self, other: &Self) -> bool {
409 self.num_dims == other.num_dims
410 && self.positive_indices == other.positive_indices
411 && self.negative_indices == other.negative_indices
412 }
413}
414
415impl Eq for SparseVec {}
416
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_new() {
        let vec = SparseVec::new(1000);
        assert_eq!(vec.len(), 1000);
        assert_eq!(vec.count_nonzero(), 0);
        assert!((vec.sparsity() - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_sparse_set_get() {
        let mut vec = SparseVec::new(100);

        vec.set(10, Trit::P);
        vec.set(20, Trit::N);
        vec.set(50, Trit::P);

        assert_eq!(vec.get(10), Trit::P);
        assert_eq!(vec.get(20), Trit::N);
        assert_eq!(vec.get(50), Trit::P);
        assert_eq!(vec.get(0), Trit::Z);
        assert_eq!(vec.get(99), Trit::Z);
    }

    #[test]
    fn test_sparse_overwrite() {
        let mut vec = SparseVec::new(10);

        vec.set(0, Trit::P);
        assert_eq!(vec.get(0), Trit::P);
        assert_eq!(vec.count_nonzero(), 1);

        vec.set(0, Trit::N);
        assert_eq!(vec.get(0), Trit::N);
        assert_eq!(vec.count_nonzero(), 1);

        vec.set(0, Trit::Z);
        assert_eq!(vec.get(0), Trit::Z);
        assert_eq!(vec.count_nonzero(), 0);
    }

    #[test]
    fn test_sparse_dot() {
        let mut a = SparseVec::new(100);
        let mut b = SparseVec::new(100);

        a.set(0, Trit::P);
        a.set(1, Trit::N);
        a.set(10, Trit::P);

        b.set(0, Trit::P);
        b.set(1, Trit::P);
        b.set(20, Trit::N);

        // dim 0 agrees (+1), dim 1 disagrees (-1).
        assert_eq!(a.dot(&b), 0);

        // Now dim 1 agrees too.
        b.set(1, Trit::N);
        assert_eq!(a.dot(&b), 2);
    }

    #[test]
    fn test_sparse_dot_packed() {
        let mut sparse = SparseVec::new(64);
        let mut packed = PackedTritVec::new(64);

        sparse.set(0, Trit::P);
        sparse.set(1, Trit::N);

        packed.set(0, Trit::P);
        packed.set(1, Trit::P);
        packed.set(2, Trit::N);

        assert_eq!(sparse.dot_packed(&packed), 0);

        packed.set(1, Trit::N);
        assert_eq!(sparse.dot_packed(&packed), 2);
    }

    #[test]
    fn test_sparse_from_trits() {
        let trits = [Trit::P, Trit::N, Trit::Z, Trit::P, Trit::Z];
        let vec = SparseVec::from_trits(&trits);

        assert_eq!(vec.len(), 5);
        assert_eq!(vec.count_positive(), 2);
        assert_eq!(vec.count_negative(), 1);

        assert_eq!(vec.to_trits(), trits);
    }

    #[test]
    fn test_sparse_to_packed_roundtrip() {
        let mut sparse = SparseVec::new(100);
        sparse.set(0, Trit::P);
        sparse.set(50, Trit::N);
        sparse.set(99, Trit::P);

        let packed = sparse.to_packed();
        let back = SparseVec::from_packed(&packed);

        assert_eq!(sparse, back);
    }

    #[test]
    fn test_sparse_negated() {
        let mut vec = SparseVec::new(10);
        vec.set(0, Trit::P);
        vec.set(1, Trit::N);

        let neg = vec.negated();

        assert_eq!(neg.get(0), Trit::N);
        assert_eq!(neg.get(1), Trit::P);
    }

    #[test]
    fn test_sparse_from_indices() {
        let pos = vec![0, 10, 50];
        let neg = vec![5, 20];
        let vec = SparseVec::from_indices(pos, neg, 100).unwrap();

        assert_eq!(vec.get(0), Trit::P);
        assert_eq!(vec.get(10), Trit::P);
        assert_eq!(vec.get(50), Trit::P);
        assert_eq!(vec.get(5), Trit::N);
        assert_eq!(vec.get(20), Trit::N);
        assert_eq!(vec.get(1), Trit::Z);
    }

    #[test]
    fn test_sparse_from_indices_unsorted() {
        let v = SparseVec::from_indices(vec![50, 0, 10], vec![20, 5], 100).unwrap();

        assert_eq!(v.positive_indices().to_vec(), vec![0, 10, 50]);
        assert_eq!(v.negative_indices().to_vec(), vec![5, 20]);
    }

    #[test]
    fn test_sparse_from_indices_overlap_error() {
        let pos = vec![0, 10];
        let neg = vec![10, 20];
        let result = SparseVec::from_indices(pos, neg, 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_from_indices_bounds_error() {
        let pos = vec![100];
        let neg = vec![];
        let result = SparseVec::from_indices(pos, neg, 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_sum() {
        let mut vec = SparseVec::new(100);
        vec.set(0, Trit::P);
        vec.set(1, Trit::P);
        vec.set(2, Trit::N);

        assert_eq!(vec.sum(), 1);
    }

    #[test]
    fn test_sparse_zero_dims() {
        let v = SparseVec::new(0);
        assert!(v.is_empty());
        assert_eq!(v.count_nonzero(), 0);
        assert!((v.sparsity() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_sparse_self_dot_and_negation() {
        let v = SparseVec::from_indices(vec![1, 3], vec![2], 8).unwrap();

        // Every non-zero entry agrees with itself...
        assert_eq!(v.dot(&v), 3);
        // ...and disagrees with its negation.
        assert_eq!(v.dot(&v.negated()), -3);
        assert_eq!(v.negated().negated(), v);
    }

    #[test]
    #[should_panic(expected = "dimension out of bounds")]
    fn test_sparse_get_out_of_bounds() {
        let _ = SparseVec::new(10).get(10);
    }

    #[test]
    #[should_panic(expected = "dimension out of bounds")]
    fn test_sparse_set_out_of_bounds() {
        SparseVec::new(10).set(10, Trit::P);
    }

    #[test]
    #[should_panic(expected = "vectors must have same dimensions")]
    fn test_sparse_dot_dimension_mismatch() {
        let _ = SparseVec::new(10).dot(&SparseVec::new(11));
    }
}