mokosh/types/
sdr.rs

1//! Sparse Distributed Representation (SDR) implementation.
2//!
3//! An SDR is a data structure representing a group of boolean values (bits).
4//! It can be represented in three formats:
5//! - **Dense**: A contiguous array of all bits
6//! - **Sparse**: A sorted list of indices of active (true) bits
7//! - **Coordinate**: A list of coordinates for each dimension
8//!
9//! The SDR automatically converts between formats and caches results for efficiency.
10
11use crate::error::{MokoshError, Result};
12use crate::types::{ElemDense, ElemSparse, Real, UInt};
13use crate::utils::{simd, Random};
14
15use std::cell::RefCell;
16use std::fmt;
17
18/// Type alias for dense SDR data (array of bytes, 0 or 1).
19pub type SdrDense = Vec<ElemDense>;
20
21/// Type alias for sparse SDR data (sorted indices of active bits).
22pub type SdrSparse = Vec<ElemSparse>;
23
24/// Type alias for coordinate SDR data (coordinates per dimension).
25pub type SdrCoordinate = Vec<Vec<UInt>>;
26
27/// Callback function type for SDR value changes.
28pub type SdrCallback = Box<dyn Fn() + Send + Sync>;
29
30/// Internal cache state for lazy evaluation.
31#[derive(Default)]
32struct SdrCache {
33    dense: Option<SdrDense>,
34    sparse: Option<SdrSparse>,
35    coordinates: Option<SdrCoordinate>,
36}
37
38/// Sparse Distributed Representation.
39///
40/// This is the fundamental data structure in HTM. It represents a binary vector
41/// where typically only a small percentage of bits are active (true).
42///
43/// # Example
44///
45/// ```rust
46/// use mokosh::types::Sdr;
47///
48/// // Create a 10x10 SDR
49/// let mut sdr = Sdr::new(&[10, 10]);
50///
51/// // Set active bits using sparse indices
52/// sdr.set_sparse(&[1, 4, 8, 15, 42]).unwrap();
53///
54/// // Get the number of active bits
55/// assert_eq!(sdr.get_sum(), 5);
56///
57/// // Access in different formats
58/// let dense = sdr.get_dense();
59/// let sparse = sdr.get_sparse();
60/// let coords = sdr.get_coordinates();
61/// ```
62pub struct Sdr {
63    /// Dimensions of the SDR.
64    dimensions: Vec<UInt>,
65
66    /// Total size (product of dimensions).
67    size: usize,
68
69    /// Cached representations (interior mutability for lazy evaluation).
70    cache: RefCell<SdrCache>,
71
72    /// Callbacks to notify on value changes.
73    callbacks: RefCell<Vec<Option<SdrCallback>>>,
74
75    /// Callbacks to notify on destruction.
76    destroy_callbacks: RefCell<Vec<Option<SdrCallback>>>,
77}
78
// Custom serialization for `Sdr`: only the shape and the sorted active-bit
// indices are written. The dense and coordinate views are rebuilt on demand.
#[cfg(feature = "serde")]
mod serde_impl {
    use super::*;
    use serde::de::Error as _;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};

    /// On-the-wire representation of an [`Sdr`].
    #[derive(Serialize, Deserialize)]
    struct SdrState {
        dimensions: Vec<UInt>,
        sparse: Vec<ElemSparse>,
    }

    impl Serialize for Sdr {
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let state = SdrState {
                dimensions: self.dimensions.clone(),
                // `get_sparse` already returns an owned copy.
                sparse: self.get_sparse(),
            };
            state.serialize(serializer)
        }
    }

    impl<'de> Deserialize<'de> for Sdr {
        /// Rebuilds an SDR from its shape and active indices.
        ///
        /// Serialized data is treated as untrusted. A shape that `Sdr::new`
        /// would reject and indices that are unsorted, duplicated or out of
        /// range are reported as deserialization errors instead of panicking
        /// or silently corrupting the SDR's invariants.
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let state = SdrState::deserialize(deserializer)?;

            if state.dimensions.is_empty() {
                return Err(D::Error::custom("SDR dimensions cannot be empty"));
            }
            if state.dimensions.len() > 1 && state.dimensions.contains(&0) {
                return Err(D::Error::custom(
                    "SDR dimensions cannot contain zero in a multi-dimensional SDR",
                ));
            }

            let mut sdr = Sdr::new(&state.dimensions);
            sdr.set_sparse_owned(state.sparse)
                .map_err(D::Error::custom)?;
            Ok(sdr)
        }
    }
}
116
117impl Sdr {
118    /// Creates a new SDR with the given dimensions, initialized to all zeros.
119    ///
120    /// # Arguments
121    ///
122    /// * `dimensions` - The shape of the SDR (e.g., `&[10, 10]` for 10x10)
123    ///
124    /// # Panics
125    ///
126    /// Panics if dimensions is empty or contains zeros.
127    ///
128    /// # Example
129    ///
130    /// ```rust
131    /// use mokosh::types::Sdr;
132    ///
133    /// let sdr = Sdr::new(&[100]);        // 1D SDR with 100 bits
134    /// let sdr2 = Sdr::new(&[10, 10]);    // 2D SDR with 100 bits
135    /// let sdr3 = Sdr::new(&[5, 4, 5]);   // 3D SDR with 100 bits
136    /// ```
137    #[must_use]
138    pub fn new(dimensions: &[UInt]) -> Self {
139        assert!(!dimensions.is_empty(), "Dimensions cannot be empty");
140
141        let size: usize = dimensions.iter().map(|&d| d as usize).product();
142
143        // Allow size 0 for placeholder SDRs
144        for (i, &dim) in dimensions.iter().enumerate() {
145            if dim == 0 && dimensions.len() > 1 {
146                panic!("Dimension {} cannot be zero in multi-dimensional SDR", i);
147            }
148        }
149
150        Self {
151            dimensions: dimensions.to_vec(),
152            size,
153            cache: RefCell::new(SdrCache::default()),
154            callbacks: RefCell::new(Vec::new()),
155            destroy_callbacks: RefCell::new(Vec::new()),
156        }
157    }
158
159    /// Creates a new SDR with dimensions initialized from an iterator.
160    pub fn with_dimensions<I>(dimensions: I) -> Self
161    where
162        I: IntoIterator<Item = UInt>,
163    {
164        let dims: Vec<UInt> = dimensions.into_iter().collect();
165        Self::new(&dims)
166    }
167
168    /// Returns the dimensions of this SDR.
169    #[inline]
170    #[must_use]
171    pub fn dimensions(&self) -> &[UInt] {
172        &self.dimensions
173    }
174
175    /// Returns the total number of bits in the SDR.
176    #[inline]
177    #[must_use]
178    pub fn size(&self) -> usize {
179        self.size
180    }
181
182    /// Returns the number of dimensions.
183    #[inline]
184    #[must_use]
185    pub fn num_dimensions(&self) -> usize {
186        self.dimensions.len()
187    }
188
189    /// Reshapes the SDR to new dimensions. The total size must remain the same.
190    ///
191    /// # Errors
192    ///
193    /// Returns an error if the new dimensions have a different total size.
194    pub fn reshape(&mut self, new_dimensions: &[UInt]) -> Result<()> {
195        let new_size: usize = new_dimensions.iter().map(|&d| d as usize).product();
196
197        if new_size != self.size {
198            return Err(MokoshError::InvalidDimensions(format!(
199                "Cannot reshape from size {} to size {}",
200                self.size, new_size
201            )));
202        }
203
204        self.dimensions = new_dimensions.to_vec();
205
206        // Invalidate coordinate cache as it depends on dimensions
207        self.cache.borrow_mut().coordinates = None;
208
209        Ok(())
210    }
211
212    /// Sets all bits to zero.
213    pub fn zero(&mut self) {
214        let mut cache = self.cache.borrow_mut();
215        cache.dense = Some(vec![0; self.size]);
216        cache.sparse = Some(Vec::new());
217        cache.coordinates = Some(vec![Vec::new(); self.dimensions.len()]);
218        drop(cache);
219
220        self.do_callbacks();
221    }
222
223    /// Clears all cached representations.
224    fn clear_cache(&self) {
225        let mut cache = self.cache.borrow_mut();
226        cache.dense = None;
227        cache.sparse = None;
228        cache.coordinates = None;
229    }
230
231    /// Invokes all registered callbacks.
232    fn do_callbacks(&self) {
233        let callbacks = self.callbacks.borrow();
234        for callback in callbacks.iter().flatten() {
235            callback();
236        }
237    }
238
239    // ========================================================================
240    // Dense format operations
241    // ========================================================================
242
243    /// Sets the SDR value from a dense array.
244    ///
245    /// # Arguments
246    ///
247    /// * `data` - A slice of values where non-zero means active
248    ///
249    /// # Errors
250    ///
251    /// Returns an error if the data length doesn't match the SDR size.
252    pub fn set_dense(&mut self, data: &[ElemDense]) -> Result<()> {
253        if data.len() != self.size {
254            return Err(MokoshError::DimensionMismatch {
255                expected: vec![self.size as u32],
256                actual: vec![data.len() as u32],
257            });
258        }
259
260        let mut cache = self.cache.borrow_mut();
261        cache.dense = Some(data.to_vec());
262        cache.sparse = None;
263        cache.coordinates = None;
264        drop(cache);
265
266        self.do_callbacks();
267        Ok(())
268    }
269
270    /// Sets the SDR value from a dense array, consuming it to avoid copying.
271    pub fn set_dense_owned(&mut self, data: SdrDense) -> Result<()> {
272        if data.len() != self.size {
273            return Err(MokoshError::DimensionMismatch {
274                expected: vec![self.size as u32],
275                actual: vec![data.len() as u32],
276            });
277        }
278
279        let mut cache = self.cache.borrow_mut();
280        cache.dense = Some(data);
281        cache.sparse = None;
282        cache.coordinates = None;
283        drop(cache);
284
285        self.do_callbacks();
286        Ok(())
287    }
288
289    /// Gets the dense representation of the SDR.
290    ///
291    /// This method lazily computes the dense array from sparse or coordinate
292    /// representations if needed.
293    #[must_use]
294    pub fn get_dense(&self) -> SdrDense {
295        {
296            let cache = self.cache.borrow();
297            if let Some(ref dense) = cache.dense {
298                return dense.clone();
299            }
300        }
301
302        // Need to compute from sparse
303        let sparse = self.get_sparse();
304        let mut dense = vec![0u8; self.size];
305        for &idx in &sparse {
306            dense[idx as usize] = 1;
307        }
308
309        let mut cache = self.cache.borrow_mut();
310        cache.dense = Some(dense.clone());
311        dense
312    }
313
314    /// Gets a reference to the dense representation, computing if necessary.
315    pub fn with_dense<F, R>(&self, f: F) -> R
316    where
317        F: FnOnce(&SdrDense) -> R,
318    {
319        // Ensure dense is computed
320        {
321            let cache = self.cache.borrow();
322            if cache.dense.is_some() {
323                return f(cache.dense.as_ref().unwrap());
324            }
325        }
326
327        let _ = self.get_dense();
328        let cache = self.cache.borrow();
329        f(cache.dense.as_ref().unwrap())
330    }
331
332    // ========================================================================
333    // Sparse format operations
334    // ========================================================================
335
336    /// Sets the SDR value from sparse indices.
337    ///
338    /// # Arguments
339    ///
340    /// * `indices` - Sorted slice of indices of active bits
341    ///
342    /// # Errors
343    ///
344    /// Returns an error if indices are not sorted, contain duplicates, or are out of bounds.
345    pub fn set_sparse(&mut self, indices: &[ElemSparse]) -> Result<()> {
346        // Validate indices
347        self.validate_sparse(indices)?;
348
349        let mut cache = self.cache.borrow_mut();
350        cache.sparse = Some(indices.to_vec());
351        cache.dense = None;
352        cache.coordinates = None;
353        drop(cache);
354
355        self.do_callbacks();
356        Ok(())
357    }
358
359    /// Sets the SDR value from sparse indices, consuming to avoid copying.
360    pub fn set_sparse_owned(&mut self, indices: SdrSparse) -> Result<()> {
361        self.validate_sparse(&indices)?;
362
363        let mut cache = self.cache.borrow_mut();
364        cache.sparse = Some(indices);
365        cache.dense = None;
366        cache.coordinates = None;
367        drop(cache);
368
369        self.do_callbacks();
370        Ok(())
371    }
372
373    /// Sets sparse indices without validation (for internal use).
374    pub(crate) fn set_sparse_unchecked(&mut self, indices: SdrSparse) {
375        let mut cache = self.cache.borrow_mut();
376        cache.sparse = Some(indices);
377        cache.dense = None;
378        cache.coordinates = None;
379        drop(cache);
380
381        self.do_callbacks();
382    }
383
384    /// Validates sparse indices.
385    fn validate_sparse(&self, indices: &[ElemSparse]) -> Result<()> {
386        if indices.is_empty() {
387            return Ok(());
388        }
389
390        // Check bounds and ordering
391        let mut prev = indices[0];
392        if prev as usize >= self.size {
393            return Err(MokoshError::IndexOutOfBounds {
394                index: prev as usize,
395                size: self.size,
396            });
397        }
398
399        for &idx in &indices[1..] {
400            if idx <= prev {
401                return Err(MokoshError::InvalidSdrData(
402                    "Sparse indices must be sorted and unique".to_string(),
403                ));
404            }
405            if idx as usize >= self.size {
406                return Err(MokoshError::IndexOutOfBounds {
407                    index: idx as usize,
408                    size: self.size,
409                });
410            }
411            prev = idx;
412        }
413
414        Ok(())
415    }
416
417    /// Gets the sparse representation of the SDR.
418    #[must_use]
419    pub fn get_sparse(&self) -> SdrSparse {
420        {
421            let cache = self.cache.borrow();
422            if let Some(ref sparse) = cache.sparse {
423                return sparse.clone();
424            }
425        }
426
427        // Compute from dense or coordinates
428        let sparse = {
429            let cache = self.cache.borrow();
430            if let Some(ref dense) = cache.dense {
431                dense
432                    .iter()
433                    .enumerate()
434                    .filter(|(_, &v)| v != 0)
435                    .map(|(i, _)| i as ElemSparse)
436                    .collect()
437            } else if let Some(ref coords) = cache.coordinates {
438                self.coordinates_to_sparse(coords)
439            } else {
440                // No data set, return empty
441                Vec::new()
442            }
443        };
444
445        let mut cache = self.cache.borrow_mut();
446        cache.sparse = Some(sparse.clone());
447        sparse
448    }
449
450    /// Gets a reference to the sparse representation.
451    pub fn with_sparse<F, R>(&self, f: F) -> R
452    where
453        F: FnOnce(&SdrSparse) -> R,
454    {
455        {
456            let cache = self.cache.borrow();
457            if cache.sparse.is_some() {
458                return f(cache.sparse.as_ref().unwrap());
459            }
460        }
461
462        let _ = self.get_sparse();
463        let cache = self.cache.borrow();
464        f(cache.sparse.as_ref().unwrap())
465    }
466
467    // ========================================================================
468    // Coordinate format operations
469    // ========================================================================
470
471    /// Sets the SDR value from coordinates.
472    ///
473    /// # Arguments
474    ///
475    /// * `coordinates` - A vector of coordinate vectors, one per dimension
476    ///
477    /// # Errors
478    ///
479    /// Returns an error if coordinates are invalid.
480    pub fn set_coordinates(&mut self, coordinates: &SdrCoordinate) -> Result<()> {
481        if coordinates.len() != self.dimensions.len() {
482            return Err(MokoshError::InvalidDimensions(format!(
483                "Expected {} dimensions, got {}",
484                self.dimensions.len(),
485                coordinates.len()
486            )));
487        }
488
489        // Validate that all inner vectors have the same length
490        if !coordinates.is_empty() {
491            let len = coordinates[0].len();
492            for (i, coord) in coordinates.iter().enumerate() {
493                if coord.len() != len {
494                    return Err(MokoshError::InvalidSdrData(format!(
495                        "Coordinate dimension {} has length {}, expected {}",
496                        i,
497                        coord.len(),
498                        len
499                    )));
500                }
501            }
502        }
503
504        // Validate bounds
505        for (dim_idx, (coords, &dim_size)) in coordinates.iter().zip(&self.dimensions).enumerate() {
506            for &c in coords {
507                if c >= dim_size {
508                    return Err(MokoshError::IndexOutOfBounds {
509                        index: c as usize,
510                        size: dim_size as usize,
511                    });
512                }
513            }
514        }
515
516        let mut cache = self.cache.borrow_mut();
517        cache.coordinates = Some(coordinates.clone());
518        cache.dense = None;
519        cache.sparse = None;
520        drop(cache);
521
522        self.do_callbacks();
523        Ok(())
524    }
525
526    /// Gets the coordinate representation of the SDR.
527    #[must_use]
528    pub fn get_coordinates(&self) -> SdrCoordinate {
529        {
530            let cache = self.cache.borrow();
531            if let Some(ref coords) = cache.coordinates {
532                return coords.clone();
533            }
534        }
535
536        // Compute from sparse
537        let sparse = self.get_sparse();
538        let coords = self.sparse_to_coordinates(&sparse);
539
540        let mut cache = self.cache.borrow_mut();
541        cache.coordinates = Some(coords.clone());
542        coords
543    }
544
545    /// Converts flat indices to coordinates.
546    fn sparse_to_coordinates(&self, sparse: &[ElemSparse]) -> SdrCoordinate {
547        let num_dims = self.dimensions.len();
548        let mut coordinates: SdrCoordinate = vec![Vec::with_capacity(sparse.len()); num_dims];
549
550        for &flat_idx in sparse {
551            let mut idx = flat_idx as usize;
552            for dim in (0..num_dims).rev() {
553                let dim_size = self.dimensions[dim] as usize;
554                coordinates[dim].push((idx % dim_size) as UInt);
555                idx /= dim_size;
556            }
557        }
558
559        // Reverse each dimension's coordinates since we computed them backwards
560        for coords in &mut coordinates {
561            coords.reverse();
562            // Re-reverse to maintain original order
563        }
564
565        // Actually, let me reconsider - the coordinates should be in the same
566        // order as the sparse indices. Let me fix this.
567        let mut coordinates: SdrCoordinate = vec![Vec::with_capacity(sparse.len()); num_dims];
568
569        for &flat_idx in sparse {
570            let mut idx = flat_idx as usize;
571            let mut temp_coords = vec![0u32; num_dims];
572
573            for dim in (0..num_dims).rev() {
574                let dim_size = self.dimensions[dim] as usize;
575                temp_coords[dim] = (idx % dim_size) as UInt;
576                idx /= dim_size;
577            }
578
579            for (dim, &coord) in temp_coords.iter().enumerate() {
580                coordinates[dim].push(coord);
581            }
582        }
583
584        coordinates
585    }
586
587    /// Converts coordinates to flat indices.
588    fn coordinates_to_sparse(&self, coordinates: &SdrCoordinate) -> SdrSparse {
589        if coordinates.is_empty() || coordinates[0].is_empty() {
590            return Vec::new();
591        }
592
593        let num_points = coordinates[0].len();
594        let mut sparse = Vec::with_capacity(num_points);
595
596        for i in 0..num_points {
597            let mut flat_idx: usize = 0;
598            let mut multiplier: usize = 1;
599
600            for dim in (0..self.dimensions.len()).rev() {
601                flat_idx += coordinates[dim][i] as usize * multiplier;
602                multiplier *= self.dimensions[dim] as usize;
603            }
604
605            sparse.push(flat_idx as ElemSparse);
606        }
607
608        // Sort and deduplicate
609        sparse.sort_unstable();
610        sparse.dedup();
611        sparse
612    }
613
614    // ========================================================================
615    // Value queries
616    // ========================================================================
617
618    /// Returns the value at the given coordinates.
619    #[must_use]
620    pub fn at(&self, coordinates: &[UInt]) -> bool {
621        assert_eq!(
622            coordinates.len(),
623            self.dimensions.len(),
624            "Coordinate dimensions mismatch"
625        );
626
627        let flat_idx = self.coordinates_to_flat(coordinates);
628        self.with_dense(|dense| dense[flat_idx] != 0)
629    }
630
631    /// Converts coordinates to a flat index.
632    fn coordinates_to_flat(&self, coordinates: &[UInt]) -> usize {
633        let mut flat_idx: usize = 0;
634        let mut multiplier: usize = 1;
635
636        for dim in (0..self.dimensions.len()).rev() {
637            flat_idx += coordinates[dim] as usize * multiplier;
638            multiplier *= self.dimensions[dim] as usize;
639        }
640
641        flat_idx
642    }
643
644    /// Returns the number of active (true) bits.
645    #[must_use]
646    pub fn get_sum(&self) -> usize {
647        self.with_sparse(Vec::len)
648    }
649
650    /// Returns the sparsity (fraction of active bits).
651    #[must_use]
652    pub fn get_sparsity(&self) -> Real {
653        if self.size == 0 {
654            return 0.0;
655        }
656        self.get_sum() as Real / self.size as Real
657    }
658
659    /// Returns the number of bits that are active in both SDRs.
660    #[must_use]
661    pub fn get_overlap(&self, other: &Sdr) -> usize {
662        let a = self.get_sparse();
663        let b = other.get_sparse();
664
665        // Use SIMD-accelerated sorted overlap
666        simd::sorted_overlap(&a, &b)
667    }
668
669    // ========================================================================
670    // SDR operations
671    // ========================================================================
672
673    /// Copies the value from another SDR.
674    ///
675    /// # Errors
676    ///
677    /// Returns an error if dimensions don't match.
678    pub fn set_sdr(&mut self, other: &Sdr) -> Result<()> {
679        if self.dimensions != other.dimensions {
680            return Err(MokoshError::DimensionMismatch {
681                expected: self.dimensions.clone(),
682                actual: other.dimensions.clone(),
683            });
684        }
685
686        let sparse = other.get_sparse();
687        self.set_sparse_owned(sparse)
688    }
689
690    /// Randomizes the SDR with the given sparsity.
691    ///
692    /// # Arguments
693    ///
694    /// * `sparsity` - Fraction of bits to set active (0.0 to 1.0)
695    /// * `rng` - Random number generator
696    pub fn randomize(&mut self, sparsity: Real, rng: &mut Random) {
697        let num_active = ((self.size as Real) * sparsity).round() as usize;
698
699        if num_active == 0 {
700            self.zero();
701            return;
702        }
703
704        if num_active >= self.size {
705            let mut cache = self.cache.borrow_mut();
706            cache.dense = Some(vec![1; self.size]);
707            cache.sparse = Some((0..self.size as ElemSparse).collect());
708            cache.coordinates = None;
709            drop(cache);
710            self.do_callbacks();
711            return;
712        }
713
714        // Generate random indices
715        let indices = rng.sample((0..self.size as ElemSparse).collect(), num_active);
716        let mut sparse: SdrSparse = indices;
717        sparse.sort_unstable();
718
719        let mut cache = self.cache.borrow_mut();
720        cache.sparse = Some(sparse);
721        cache.dense = None;
722        cache.coordinates = None;
723        drop(cache);
724
725        self.do_callbacks();
726    }
727
728    /// Adds noise to the SDR by flipping a fraction of bits.
729    ///
730    /// # Arguments
731    ///
732    /// * `fraction_noise` - Fraction of active bits to move (0.0 to 1.0)
733    /// * `rng` - Random number generator
734    pub fn add_noise(&mut self, fraction_noise: Real, rng: &mut Random) {
735        let sparse = self.get_sparse();
736        let num_active = sparse.len();
737
738        if num_active == 0 || fraction_noise <= 0.0 {
739            return;
740        }
741
742        let num_to_flip = ((num_active as Real) * fraction_noise).round() as usize;
743        if num_to_flip == 0 {
744            return;
745        }
746
747        // Select bits to turn off
748        let turn_off = rng.sample(sparse.clone(), num_to_flip);
749
750        // Find inactive bits to turn on
751        let active_set: std::collections::HashSet<_> = sparse.iter().copied().collect();
752        let inactive: Vec<ElemSparse> = (0..self.size as ElemSparse)
753            .filter(|&i| !active_set.contains(&i))
754            .collect();
755
756        let turn_on = rng.sample(inactive, num_to_flip);
757
758        // Create new sparse representation
759        let turn_off_set: std::collections::HashSet<_> = turn_off.iter().copied().collect();
760        let mut new_sparse: SdrSparse = sparse
761            .into_iter()
762            .filter(|&i| !turn_off_set.contains(&i))
763            .chain(turn_on)
764            .collect();
765        new_sparse.sort_unstable();
766
767        let mut cache = self.cache.borrow_mut();
768        cache.sparse = Some(new_sparse);
769        cache.dense = None;
770        cache.coordinates = None;
771        drop(cache);
772
773        self.do_callbacks();
774    }
775
776    /// Computes the intersection of two SDRs into this SDR.
777    ///
778    /// # Errors
779    ///
780    /// Returns an error if dimensions don't match.
781    pub fn intersection(&mut self, a: &Sdr, b: &Sdr) -> Result<()> {
782        if a.dimensions != b.dimensions {
783            return Err(MokoshError::DimensionMismatch {
784                expected: a.dimensions.clone(),
785                actual: b.dimensions.clone(),
786            });
787        }
788
789        if self.dimensions != a.dimensions {
790            return Err(MokoshError::DimensionMismatch {
791                expected: self.dimensions.clone(),
792                actual: a.dimensions.clone(),
793            });
794        }
795
796        let sparse_a = a.get_sparse();
797        let sparse_b = b.get_sparse();
798
799        // Set intersection of sorted vectors
800        let mut result = Vec::new();
801        let mut i = 0;
802        let mut j = 0;
803
804        while i < sparse_a.len() && j < sparse_b.len() {
805            match sparse_a[i].cmp(&sparse_b[j]) {
806                std::cmp::Ordering::Less => i += 1,
807                std::cmp::Ordering::Greater => j += 1,
808                std::cmp::Ordering::Equal => {
809                    result.push(sparse_a[i]);
810                    i += 1;
811                    j += 1;
812                }
813            }
814        }
815
816        self.set_sparse_unchecked(result);
817        Ok(())
818    }
819
820    /// Computes the union of two SDRs into this SDR.
821    ///
822    /// # Errors
823    ///
824    /// Returns an error if dimensions don't match.
825    pub fn set_union(&mut self, a: &Sdr, b: &Sdr) -> Result<()> {
826        if a.dimensions != b.dimensions {
827            return Err(MokoshError::DimensionMismatch {
828                expected: a.dimensions.clone(),
829                actual: b.dimensions.clone(),
830            });
831        }
832
833        if self.dimensions != a.dimensions {
834            return Err(MokoshError::DimensionMismatch {
835                expected: self.dimensions.clone(),
836                actual: a.dimensions.clone(),
837            });
838        }
839
840        let sparse_a = a.get_sparse();
841        let sparse_b = b.get_sparse();
842
843        // Set union of sorted vectors
844        let mut result = Vec::with_capacity(sparse_a.len() + sparse_b.len());
845        let mut i = 0;
846        let mut j = 0;
847
848        while i < sparse_a.len() && j < sparse_b.len() {
849            match sparse_a[i].cmp(&sparse_b[j]) {
850                std::cmp::Ordering::Less => {
851                    result.push(sparse_a[i]);
852                    i += 1;
853                }
854                std::cmp::Ordering::Greater => {
855                    result.push(sparse_b[j]);
856                    j += 1;
857                }
858                std::cmp::Ordering::Equal => {
859                    result.push(sparse_a[i]);
860                    i += 1;
861                    j += 1;
862                }
863            }
864        }
865
866        result.extend(&sparse_a[i..]);
867        result.extend(&sparse_b[j..]);
868
869        self.set_sparse_unchecked(result);
870        Ok(())
871    }
872
873    /// Concatenates SDRs along an axis into this SDR.
874    ///
875    /// # Errors
876    ///
877    /// Returns an error if dimensions don't match.
878    pub fn concatenate(&mut self, inputs: &[&Sdr], axis: usize) -> Result<()> {
879        if inputs.is_empty() {
880            return Err(MokoshError::InvalidParameter {
881                name: "inputs",
882                message: "Cannot concatenate empty list".to_string(),
883            });
884        }
885
886        // Verify all inputs have compatible dimensions
887        let num_dims = inputs[0].num_dimensions();
888        for (i, input) in inputs.iter().enumerate() {
889            if input.num_dimensions() != num_dims {
890                return Err(MokoshError::InvalidDimensions(format!(
891                    "Input {} has {} dimensions, expected {}",
892                    i,
893                    input.num_dimensions(),
894                    num_dims
895                )));
896            }
897        }
898
899        // Compute output sparse representation
900        let mut result = Vec::new();
901        let mut offset: usize = 0;
902
903        for input in inputs {
904            let sparse = input.get_sparse();
905            for &idx in &sparse {
906                result.push((idx as usize + offset) as ElemSparse);
907            }
908            offset += input.size();
909        }
910
911        // Verify result fits in this SDR
912        if offset != self.size {
913            return Err(MokoshError::DimensionMismatch {
914                expected: vec![self.size as u32],
915                actual: vec![offset as u32],
916            });
917        }
918
919        self.set_sparse_unchecked(result);
920        Ok(())
921    }
922
923    // ========================================================================
924    // Callbacks
925    // ========================================================================
926
927    /// Adds a callback that is called whenever the SDR value changes.
928    ///
929    /// Returns a handle that can be used to remove the callback.
930    pub fn add_callback(&self, callback: SdrCallback) -> usize {
931        let mut callbacks = self.callbacks.borrow_mut();
932        let handle = callbacks.len();
933        callbacks.push(Some(callback));
934        handle
935    }
936
937    /// Removes a callback by its handle.
938    pub fn remove_callback(&self, handle: usize) -> Result<()> {
939        let mut callbacks = self.callbacks.borrow_mut();
940        if handle >= callbacks.len() || callbacks[handle].is_none() {
941            return Err(MokoshError::InvalidParameter {
942                name: "handle",
943                message: format!("Invalid callback handle: {}", handle),
944            });
945        }
946        callbacks[handle] = None;
947        Ok(())
948    }
949
950    /// Adds a callback that is called when the SDR is destroyed.
951    pub fn add_destroy_callback(&self, callback: SdrCallback) -> usize {
952        let mut callbacks = self.destroy_callbacks.borrow_mut();
953        let handle = callbacks.len();
954        callbacks.push(Some(callback));
955        handle
956    }
957
958    /// Removes a destroy callback by its handle.
959    pub fn remove_destroy_callback(&self, handle: usize) -> Result<()> {
960        let mut callbacks = self.destroy_callbacks.borrow_mut();
961        if handle >= callbacks.len() || callbacks[handle].is_none() {
962            return Err(MokoshError::InvalidParameter {
963                name: "handle",
964                message: format!("Invalid destroy callback handle: {}", handle),
965            });
966        }
967        callbacks[handle] = None;
968        Ok(())
969    }
970}
971
972impl Clone for Sdr {
973    fn clone(&self) -> Self {
974        let mut new_sdr = Self::new(&self.dimensions);
975
976        // Copy the most efficient representation available
977        let cache = self.cache.borrow();
978        if let Some(ref sparse) = cache.sparse {
979            new_sdr.cache.borrow_mut().sparse = Some(sparse.clone());
980        } else if let Some(ref dense) = cache.dense {
981            new_sdr.cache.borrow_mut().dense = Some(dense.clone());
982        } else if let Some(ref coords) = cache.coordinates {
983            new_sdr.cache.borrow_mut().coordinates = Some(coords.clone());
984        }
985
986        new_sdr
987    }
988}
989
990impl PartialEq for Sdr {
991    fn eq(&self, other: &Self) -> bool {
992        if self.dimensions != other.dimensions {
993            return false;
994        }
995        self.get_sparse() == other.get_sparse()
996    }
997}
998
999impl Eq for Sdr {}
1000
1001impl fmt::Debug for Sdr {
1002    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1003        let sparse = self.get_sparse();
1004        write!(f, "SDR({:?}) {:?}", self.dimensions, sparse)
1005    }
1006}
1007
1008impl fmt::Display for Sdr {
1009    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1010        write!(f, "SDR( ")?;
1011        for (i, dim) in self.dimensions.iter().enumerate() {
1012            write!(f, "{}", dim)?;
1013            if i + 1 != self.dimensions.len() {
1014                write!(f, ", ")?;
1015            }
1016        }
1017        write!(f, " ) ")?;
1018
1019        let sparse = self.get_sparse();
1020        for (i, &idx) in sparse.iter().enumerate() {
1021            write!(f, "{}", idx)?;
1022            if i + 1 != sparse.len() {
1023                write!(f, ", ")?;
1024            }
1025        }
1026        Ok(())
1027    }
1028}
1029
1030impl Drop for Sdr {
1031    fn drop(&mut self) {
1032        let callbacks = self.destroy_callbacks.borrow();
1033        for callback in callbacks.iter().flatten() {
1034            callback();
1035        }
1036    }
1037}
1038
1039impl Default for Sdr {
1040    fn default() -> Self {
1041        Self::new(&[0])
1042    }
1043}
1044
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_constructor() {
        let sdr = Sdr::new(&[3]);
        assert_eq!(sdr.size(), 3);
        assert_eq!(sdr.dimensions(), &[3]);

        let sdr2 = Sdr::new(&[3, 4, 5]);
        assert_eq!(sdr2.size(), 60);
        assert_eq!(sdr2.dimensions(), &[3, 4, 5]);
    }

    #[test]
    fn test_empty_placeholder() {
        let sdr = Sdr::new(&[0]);
        assert_eq!(sdr.size(), 0);
    }

    #[test]
    fn test_zero() {
        let mut sdr = Sdr::new(&[4, 4]);
        sdr.set_dense(&[1; 16]).unwrap();
        sdr.zero();
        assert_eq!(sdr.get_sum(), 0);
    }

    #[test]
    fn test_dense_sparse_conversion() {
        let mut sdr = Sdr::new(&[9]);
        sdr.set_dense(&[0, 1, 0, 0, 1, 0, 0, 0, 1]).unwrap();
        assert_eq!(sdr.get_sparse(), vec![1, 4, 8]);

        sdr.set_sparse(&[1, 4, 8]).unwrap();
        assert_eq!(sdr.get_dense(), vec![0, 1, 0, 0, 1, 0, 0, 0, 1]);
    }

    #[test]
    fn test_coordinates() {
        let mut sdr = Sdr::new(&[3, 3]);
        sdr.set_coordinates(&vec![vec![0, 1, 2], vec![1, 1, 2]]).unwrap();
        assert_eq!(sdr.get_sparse(), vec![1, 4, 8]);

        sdr.set_sparse(&[1, 4, 8]).unwrap();
        let coords = sdr.get_coordinates();
        assert_eq!(coords, vec![vec![0, 1, 2], vec![1, 1, 2]]);
    }

    #[test]
    fn test_at() {
        let mut sdr = Sdr::new(&[3, 3]);
        sdr.set_sparse(&[4, 5, 8]).unwrap();
        assert!(sdr.at(&[1, 1]));
        assert!(sdr.at(&[1, 2]));
        assert!(sdr.at(&[2, 2]));
        assert!(!sdr.at(&[0, 0]));
    }

    #[test]
    fn test_sum_sparsity() {
        let mut sdr = Sdr::new(&[100]);
        sdr.set_sparse(&[1, 2, 3, 4, 5]).unwrap();
        assert_eq!(sdr.get_sum(), 5);
        assert!((sdr.get_sparsity() - 0.05).abs() < 0.001);
    }

    #[test]
    fn test_overlap() {
        let mut a = Sdr::new(&[9]);
        let mut b = Sdr::new(&[9]);
        a.set_sparse(&[1, 2, 3, 4]).unwrap();
        b.set_sparse(&[2, 3, 4, 5]).unwrap();
        assert_eq!(a.get_overlap(&b), 3);
    }

    #[test]
    fn test_intersection() {
        let mut a = Sdr::new(&[10]);
        let mut b = Sdr::new(&[10]);
        let mut c = Sdr::new(&[10]);

        a.set_sparse(&[0, 1, 2, 3]).unwrap();
        b.set_sparse(&[2, 3, 4, 5]).unwrap();
        c.intersection(&a, &b).unwrap();

        assert_eq!(c.get_sparse(), vec![2, 3]);
    }

    #[test]
    fn test_union() {
        let mut a = Sdr::new(&[10]);
        let mut b = Sdr::new(&[10]);
        let mut c = Sdr::new(&[10]);

        a.set_sparse(&[0, 1, 2, 3]).unwrap();
        b.set_sparse(&[2, 3, 4, 5]).unwrap();
        c.set_union(&a, &b).unwrap();

        assert_eq!(c.get_sparse(), vec![0, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_concatenate() {
        let mut a = Sdr::new(&[10]);
        let mut b = Sdr::new(&[10]);
        let mut c = Sdr::new(&[20]);

        a.set_sparse(&[0, 1, 2]).unwrap();
        b.set_sparse(&[0, 1, 2]).unwrap();
        c.concatenate(&[&a, &b], 0).unwrap();

        assert_eq!(c.get_sparse(), vec![0, 1, 2, 10, 11, 12]);
    }

    #[test]
    fn test_equality() {
        let mut a = Sdr::new(&[10]);
        let mut b = Sdr::new(&[10]);

        a.set_sparse(&[1, 2, 3]).unwrap();
        b.set_sparse(&[1, 2, 3]).unwrap();
        assert_eq!(a, b);

        b.set_sparse(&[1, 2, 4]).unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn test_reshape() {
        let mut sdr = Sdr::new(&[3, 4, 5]);
        sdr.set_sparse(&[0, 5, 10]).unwrap();

        sdr.reshape(&[5, 12]).unwrap();
        assert_eq!(sdr.dimensions(), &[5, 12]);
        assert_eq!(sdr.get_sparse(), vec![0, 5, 10]);
    }

    #[test]
    fn test_display() {
        let mut sdr = Sdr::new(&[3, 3]);
        sdr.set_sparse(&[1, 4, 8]).unwrap();
        let s = format!("{}", sdr);
        assert!(s.contains("SDR( 3, 3 )"));
        assert!(s.contains("1, 4, 8"));
    }

    #[test]
    fn test_clone() {
        let mut sdr = Sdr::new(&[10]);
        sdr.set_sparse(&[1, 2, 3]).unwrap();

        let cloned = sdr.clone();
        assert_eq!(sdr, cloned);

        // Verify deep copy
        sdr.set_sparse(&[4, 5, 6]).unwrap();
        assert_ne!(sdr, cloned);
    }

    #[test]
    fn test_callback_handles() {
        let sdr = Sdr::new(&[4]);
        let first = sdr.add_callback(Box::new(|| {}));
        let second = sdr.add_callback(Box::new(|| {}));
        assert_ne!(first, second);

        sdr.remove_callback(first).unwrap();
        // Removing the same handle twice, or one that was never issued, is an error.
        assert!(matches!(
            sdr.remove_callback(first),
            Err(MokoshError::InvalidParameter { .. })
        ));
        assert!(sdr.remove_callback(second + 1).is_err());

        // Removing one callback must not invalidate the other handle.
        sdr.remove_callback(second).unwrap();
    }

    #[test]
    fn test_destroy_callback_runs_on_drop() {
        let fired = Arc::new(AtomicUsize::new(0));
        let sdr = Sdr::new(&[4]);
        let counter = Arc::clone(&fired);
        sdr.add_destroy_callback(Box::new(move || {
            counter.fetch_add(1, Ordering::SeqCst);
        }));
        assert_eq!(fired.load(Ordering::SeqCst), 0);

        drop(sdr);
        assert_eq!(fired.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_removed_destroy_callback_does_not_run() {
        let fired = Arc::new(AtomicUsize::new(0));
        let sdr = Sdr::new(&[4]);
        let counter = Arc::clone(&fired);
        let handle = sdr.add_destroy_callback(Box::new(move || {
            counter.fetch_add(1, Ordering::SeqCst);
        }));

        sdr.remove_destroy_callback(handle).unwrap();
        assert!(sdr.remove_destroy_callback(handle).is_err());

        drop(sdr);
        assert_eq!(fired.load(Ordering::SeqCst), 0);
    }
}