// embeddenator_vsa/resonator.rs

//! Resonator Networks for Learned Codebooks
//!
//! This module implements resonator networks - a neural-inspired architecture for
//! iterative factorization and semantic variable inference in Vector Symbolic
//! Architectures (VSA).
//!
//! # Architecture
//!
//! Resonator networks solve the factorization problem: given a composite vector
//! that is the binding of multiple factors, recover the original factors.
//!
//! ```text
//! Input: x = f₁ ⊙ f₂ ⊙ f₃  (bound composite)
//!
//! Resonator Loop:
//!   estimate_f₁ = unbind(x, estimate_f₂, estimate_f₃)
//!   estimate_f₁ = cleanup(estimate_f₁, codebook_1)
//!   estimate_f₂ = unbind(x, estimate_f₁, estimate_f₃)
//!   estimate_f₂ = cleanup(estimate_f₂, codebook_2)
//!   estimate_f₃ = unbind(x, estimate_f₁, estimate_f₂)
//!   estimate_f₃ = cleanup(estimate_f₃, codebook_3)
//!
//! Repeat until convergence or max iterations
//! ```
//!
//! # Gradient Learning
//!
//! The codebook vectors can be optimized through gradient descent on a
//! reconstruction loss, enabling learned representations that better capture
//! the structure of training data.
//!
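//! A minimal sketch of a training call (hypothetical setup: `resonator` already has
//! its codebooks registered and `examples: Vec<TrainingExample>` has been built, e.g.
//! with [`TrainingExample::from_codebooks`]):
//!
//! ```rust,ignore
//! let report = resonator.train(&examples, 10).expect("non-empty training data");
//! println!("final loss: {:.4}", report.final_loss);
//! ```
//!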
//! # Example
//!
//! ```rust,ignore
//! use embeddenator_vsa::{Resonator, ResonatorConfig, Codebook};
//!
//! let config = ResonatorConfig::default();
//! let mut resonator = Resonator::new(config);
//!
//! // Add codebooks for each factor type
//! resonator.add_codebook("type", type_codebook);
//! resonator.add_codebook("position", position_codebook);
//!
//! // Factorize a composite vector
//! let result = resonator.factorize(&composite_vec);
//! println!("Type: {:?}", result.factors.get("type"));
//! println!("Position: {:?}", result.factors.get("position"));
//! ```

use std::collections::HashMap;

use crate::codebook::Codebook;
use crate::vsa::{SparseVec, DIM};

/// Configuration for resonator network
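///
/// All fields are plain data, so a config is typically built by overriding a few
/// fields on top of [`Default`] (a sketch):
///
/// ```rust,ignore
/// let config = ResonatorConfig {
///     max_iterations: 100,
///     soft_cleanup: false,
///     ..ResonatorConfig::default()
/// };
/// ```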
#[derive(Clone, Debug)]
pub struct ResonatorConfig {
    /// Maximum iterations before giving up
    pub max_iterations: usize,
    /// Convergence threshold (cosine similarity to previous estimate)
    pub convergence_threshold: f64,
    /// Learning rate for gradient updates during training
    pub learning_rate: f64,
    /// Momentum coefficient for gradient updates
    pub momentum: f64,
    /// Weight decay for regularization
    pub weight_decay: f64,
    /// Temperature for softmax cleanup (lower = sharper)
    pub temperature: f64,
    /// Whether to use soft cleanup (weighted average) vs hard (argmax)
    pub soft_cleanup: bool,
    /// Number of top candidates to consider in soft cleanup
    pub soft_cleanup_top_k: usize,
}

impl Default for ResonatorConfig {
    fn default() -> Self {
        Self {
            max_iterations: 50,
            convergence_threshold: 0.99,
            learning_rate: 0.01,
            momentum: 0.9,
            weight_decay: 1e-5,
            temperature: 0.1,
            soft_cleanup: true,
            soft_cleanup_top_k: 8,
        }
    }
}

/// Result of resonator factorization
#[derive(Clone, Debug)]
pub struct FactorizationResult {
    /// Recovered factors by codebook name
    pub factors: HashMap<String, RecoveredFactor>,
    /// Number of iterations used
    pub iterations: usize,
    /// Whether convergence was achieved
    pub converged: bool,
    /// Final reconstruction quality (cosine similarity to input)
    pub reconstruction_quality: f64,
    /// Per-iteration convergence metrics
    pub convergence_history: Vec<f64>,
}

/// A recovered factor from factorization
#[derive(Clone, Debug)]
pub struct RecoveredFactor {
    /// Best matching basis vector ID
    pub best_match_id: u32,
    /// Best matching basis vector (if available)
    pub best_match: Option<SparseVec>,
    /// Similarity to best match
    pub confidence: f64,
    /// The estimated factor vector
    pub estimate: SparseVec,
    /// Top-k candidate matches with similarities
    pub candidates: Vec<(u32, f64)>,
}

/// Gradient state for a single codebook during training
#[derive(Clone, Debug)]
struct CodebookGradient {
    /// Accumulated gradients for each basis vector (by ID)
    gradients: HashMap<u32, Vec<f64>>,
    /// Momentum terms for each basis vector
    momentum: HashMap<u32, Vec<f64>>,
}

impl CodebookGradient {
    fn new() -> Self {
        Self {
            gradients: HashMap::new(),
            momentum: HashMap::new(),
        }
    }

    fn zero_gradients(&mut self) {
        for grad in self.gradients.values_mut() {
            grad.fill(0.0);
        }
    }
}

/// Resonator network for iterative factorization and learning
pub struct Resonator {
    /// Configuration
    config: ResonatorConfig,
    /// Named codebooks for each factor type
    codebooks: HashMap<String, Codebook>,
    /// Order of factor estimation (determines unbinding order)
    factor_order: Vec<String>,
    /// Gradient accumulators for training
    gradients: HashMap<String, CodebookGradient>,
    /// Training statistics
    stats: ResonatorStats,
}

/// Statistics from resonator training
#[derive(Clone, Debug, Default)]
pub struct ResonatorStats {
    /// Total factorization attempts
    pub total_factorizations: u64,
    /// Successful convergences
    pub converged_count: u64,
    /// Average iterations to converge
    pub avg_iterations: f64,
    /// Average reconstruction quality
    pub avg_reconstruction_quality: f64,
    /// Total training steps
    pub training_steps: u64,
    /// Current average loss
    pub current_loss: f64,
}

impl Resonator {
    /// Create a new resonator network
    pub fn new(config: ResonatorConfig) -> Self {
        Self {
            config,
            codebooks: HashMap::new(),
            factor_order: Vec::new(),
            gradients: HashMap::new(),
            stats: ResonatorStats::default(),
        }
    }

    /// Add a codebook for a factor type
    pub fn add_codebook(&mut self, name: &str, codebook: Codebook) {
        self.codebooks.insert(name.to_string(), codebook);
        self.factor_order.push(name.to_string());
        self.gradients
            .insert(name.to_string(), CodebookGradient::new());
    }

    /// Get a codebook by name
    pub fn get_codebook(&self, name: &str) -> Option<&Codebook> {
        self.codebooks.get(name)
    }

    /// Get mutable codebook by name
    pub fn get_codebook_mut(&mut self, name: &str) -> Option<&mut Codebook> {
        self.codebooks.get_mut(name)
    }

    /// Get current statistics
    pub fn stats(&self) -> &ResonatorStats {
        &self.stats
    }

    /// Factorize a composite vector into its constituent factors
    ///
    /// The composite is assumed to be the binding (⊙) of factors from
    /// each codebook. The resonator iteratively estimates each factor.
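    ///
    /// # Example
    ///
    /// A sketch of a two-factor round trip (hypothetical setup: `type_codebook` and
    /// `position_codebook` are already populated, and `type_vec` / `position_vec` are
    /// basis vectors drawn from them):
    ///
    /// ```rust,ignore
    /// let mut resonator = Resonator::new(ResonatorConfig::default());
    /// resonator.add_codebook("type", type_codebook);
    /// resonator.add_codebook("position", position_codebook);
    ///
    /// // Bind two basis vectors into a composite, then recover them.
    /// let composite = type_vec.bind(&position_vec);
    /// let result = resonator.factorize(&composite);
    ///
    /// if result.converged {
    ///     let type_factor = &result.factors["type"];
    ///     println!("type id {} ({:.2})", type_factor.best_match_id, type_factor.confidence);
    /// }
    /// ```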
    pub fn factorize(&mut self, composite: &SparseVec) -> FactorizationResult {
        self.factorize_with_iterations(composite, self.config.max_iterations)
    }

    /// Factorize with a specific maximum iteration count
    pub fn factorize_with_iterations(
        &mut self,
        composite: &SparseVec,
        max_iterations: usize,
    ) -> FactorizationResult {
        if self.factor_order.is_empty() {
            return FactorizationResult {
                factors: HashMap::new(),
                iterations: 0,
                converged: true,
                reconstruction_quality: 0.0,
                convergence_history: Vec::new(),
            };
        }

        // Initialize all factor estimates with random vectors
        let mut estimates: HashMap<String, SparseVec> = self
            .factor_order
            .iter()
            .map(|name| (name.clone(), SparseVec::random()))
            .collect();

        let mut convergence_history = Vec::new();
        let mut prev_estimates = estimates.clone();
        let mut converged = false;

        for iteration in 0..max_iterations {
            // Update each factor estimate
            for name in &self.factor_order.clone() {
                // Compute the unbinding: x ⊙ f₂⁻¹ ⊙ f₃⁻¹ ... to estimate f₁
                let mut unbound = composite.clone();
                for (other_name, other_estimate) in &estimates {
                    if other_name != name {
                        // Unbind by binding with inverse (self-inverse property for sparse ternary)
                        unbound = unbound.bind(other_estimate);
                    }
                }

                // Cleanup: project onto codebook
                let cleaned = if let Some(codebook) = self.codebooks.get(name) {
                    self.cleanup(&unbound, codebook)
                } else {
                    unbound
                };

                estimates.insert(name.clone(), cleaned);
            }

            // Check convergence
            let mut min_similarity = 1.0f64;
            for name in &self.factor_order {
                let curr = estimates.get(name).unwrap();
                let prev = prev_estimates.get(name).unwrap();
                let sim = curr.cosine(prev);
                min_similarity = min_similarity.min(sim);
            }
            convergence_history.push(min_similarity);

            if min_similarity >= self.config.convergence_threshold {
                converged = true;
                self.stats.converged_count += 1;
            }

            prev_estimates = estimates.clone();

            if converged {
                // Compute final metrics
                let reconstruction = self.reconstruct(&estimates);
                let quality = reconstruction.cosine(composite);

                self.stats.total_factorizations += 1;
                self.stats.avg_reconstruction_quality = (self.stats.avg_reconstruction_quality
                    * (self.stats.total_factorizations - 1) as f64
                    + quality)
                    / self.stats.total_factorizations as f64;
                self.stats.avg_iterations = (self.stats.avg_iterations
                    * (self.stats.total_factorizations - 1) as f64
                    + (iteration + 1) as f64)
                    / self.stats.total_factorizations as f64;

                return self.build_result(
                    estimates,
                    iteration + 1,
                    true,
                    quality,
                    convergence_history,
                );
            }
        }

        // Did not converge
        self.stats.total_factorizations += 1;
        let reconstruction = self.reconstruct(&estimates);
        let quality = reconstruction.cosine(composite);

        self.build_result(
            estimates,
            max_iterations,
            false,
            quality,
            convergence_history,
        )
    }

    /// Cleanup operation: project vector onto codebook
    fn cleanup(&self, vec: &SparseVec, codebook: &Codebook) -> SparseVec {
        if codebook.basis_vectors.is_empty() {
            return vec.clone();
        }

        if self.config.soft_cleanup {
            // Soft cleanup: weighted average of top-k matches
            self.soft_cleanup(vec, codebook)
        } else {
            // Hard cleanup: return the best match
            self.hard_cleanup(vec, codebook)
        }
    }

    /// Hard cleanup: return best matching basis vector
    fn hard_cleanup(&self, vec: &SparseVec, codebook: &Codebook) -> SparseVec {
        let mut best_sim = f64::NEG_INFINITY;
        let mut best_vec = vec.clone();

        for basis in &codebook.basis_vectors {
            let sim = vec.cosine(&basis.vector);
            if sim > best_sim {
                best_sim = sim;
                best_vec = basis.vector.clone();
            }
        }

        best_vec
    }

    /// Soft cleanup: weighted average of top-k matches
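    ///
    /// Candidate weights follow a temperature-scaled softmax over cosine similarities,
    /// w_i = exp((s_i - s_max) / T) / Σ_j exp((s_j - s_max) / T), restricted to the
    /// top-k matches (`soft_cleanup_top_k`) with T = `temperature`.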
    fn soft_cleanup(&self, vec: &SparseVec, codebook: &Codebook) -> SparseVec {
        // Compute similarities to all basis vectors
        let mut similarities: Vec<(usize, f64)> = codebook
            .basis_vectors
            .iter()
            .enumerate()
            .map(|(i, basis)| (i, vec.cosine(&basis.vector)))
            .collect();

        // Sort by similarity descending
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Take top-k
        let top_k: Vec<_> = similarities
            .into_iter()
            .take(self.config.soft_cleanup_top_k)
            .collect();

        if top_k.is_empty() {
            return vec.clone();
        }

        // Compute softmax weights
        let max_sim = top_k.first().map(|t| t.1).unwrap_or(0.0);
        let weights: Vec<f64> = top_k
            .iter()
            .map(|(_, sim)| ((sim - max_sim) / self.config.temperature).exp())
            .collect();
        let weight_sum: f64 = weights.iter().sum();

        if weight_sum == 0.0 {
            return codebook.basis_vectors[top_k[0].0].vector.clone();
        }

        // Weighted bundle of top-k vectors
        let weighted_vecs: Vec<_> = top_k
            .iter()
            .zip(weights.iter())
            .map(|((idx, _), w)| (codebook.basis_vectors[*idx].vector.clone(), *w / weight_sum))
            .collect();

        weighted_bundle(&weighted_vecs)
    }

    /// Reconstruct composite from factors
    fn reconstruct(&self, factors: &HashMap<String, SparseVec>) -> SparseVec {
        let mut result = SparseVec::random(); // Placeholder; replaced by the first factor below
        let mut first = true;

        for name in &self.factor_order {
            if let Some(factor) = factors.get(name) {
                if first {
                    result = factor.clone();
                    first = false;
                } else {
                    result = result.bind(factor);
                }
            }
        }

        result
    }

    /// Build the factorization result
    fn build_result(
        &self,
        estimates: HashMap<String, SparseVec>,
        iterations: usize,
        converged: bool,
        quality: f64,
        convergence_history: Vec<f64>,
    ) -> FactorizationResult {
        let mut factors = HashMap::new();

        for (name, estimate) in estimates {
            if let Some(codebook) = self.codebooks.get(&name) {
                // Find best match and candidates
                let mut candidates: Vec<(u32, f64)> = codebook
                    .basis_vectors
                    .iter()
                    .map(|b| (b.id, estimate.cosine(&b.vector)))
                    .collect();
                candidates
                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

                let best = candidates.first().cloned().unwrap_or((0, 0.0));
                let best_match = codebook
                    .basis_vectors
                    .iter()
                    .find(|b| b.id == best.0)
                    .map(|b| b.vector.clone());

                factors.insert(
                    name,
                    RecoveredFactor {
                        best_match_id: best.0,
                        best_match,
                        confidence: best.1,
                        estimate,
                        candidates: candidates.into_iter().take(10).collect(),
                    },
                );
            } else {
                factors.insert(
                    name,
                    RecoveredFactor {
                        best_match_id: 0,
                        best_match: None,
                        confidence: 0.0,
                        estimate,
                        candidates: Vec::new(),
                    },
                );
            }
        }

        FactorizationResult {
            factors,
            iterations,
            converged,
            reconstruction_quality: quality,
            convergence_history,
        }
    }

    /// Train the codebooks using gradient-based optimization
    ///
    /// Takes training pairs of (composite, expected_factors) and optimizes
    /// the codebook vectors to minimize reconstruction loss.
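    ///
    /// # Example
    ///
    /// A minimal sketch (hypothetical setup: `codebooks` maps factor names to
    /// `&Codebook` values that are already registered on `resonator`):
    ///
    /// ```rust,ignore
    /// let mut factor_ids = HashMap::new();
    /// factor_ids.insert("type".to_string(), 1u32);
    /// factor_ids.insert("position".to_string(), 3u32);
    ///
    /// let example = TrainingExample::from_codebooks(&codebooks, &factor_ids).unwrap();
    /// let report = resonator.train(&[example], 5)?;
    /// println!("loss history: {:?}", report.loss_history);
    /// ```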
    pub fn train(
        &mut self,
        training_data: &[TrainingExample],
        epochs: usize,
    ) -> Result<TrainingResult, String> {
        if training_data.is_empty() {
            return Err("No training data provided".to_string());
        }

        let mut loss_history = Vec::new();
        let batch_size = 32.min(training_data.len());

        for _epoch in 0..epochs {
            let mut epoch_loss = 0.0;

            // Process in batches
            for batch in training_data.chunks(batch_size) {
                self.zero_gradients();

                let mut batch_loss = 0.0;
                for example in batch {
                    // Forward pass: factorize
                    let result = self.factorize(&example.composite);

                    // Compute loss: reconstruction loss (1 - quality) + factor matching loss
                    let recon_loss = 1.0 - result.reconstruction_quality;
                    let factor_loss = self.compute_factor_loss(&result, &example.expected_factors);
                    let total_loss = recon_loss + factor_loss;
                    batch_loss += total_loss;

                    // Backward pass: accumulate gradients
                    self.backward(&example.composite, &result, &example.expected_factors);
                }

                // Apply gradients
                self.apply_gradients(batch.len());

                epoch_loss += batch_loss;
            }

            // Average over all examples (the final batch may be smaller than batch_size)
            let avg_loss = epoch_loss / training_data.len() as f64;
            loss_history.push(avg_loss);
            self.stats.current_loss = avg_loss;
            self.stats.training_steps += 1;
        }

        Ok(TrainingResult {
            final_loss: *loss_history.last().unwrap_or(&0.0),
            loss_history,
            epochs_completed: epochs,
        })
    }

    /// Zero all gradient accumulators
    fn zero_gradients(&mut self) {
        for grad in self.gradients.values_mut() {
            grad.zero_gradients();
        }
    }

    /// Compute factor matching loss
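    ///
    /// Each expected factor contributes `1.0 - confidence` when the recovered
    /// `best_match_id` differs from the expected ID (0 when it matches); the loss
    /// is the mean of these contributions over the expected factors.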
    fn compute_factor_loss(
        &self,
        result: &FactorizationResult,
        expected: &HashMap<String, u32>,
    ) -> f64 {
        let mut loss = 0.0;
        let mut count = 0;

        for (name, expected_id) in expected {
            if let Some(factor) = result.factors.get(name) {
                // Loss: 1 - confidence if wrong match, 0 if correct
                if factor.best_match_id != *expected_id {
                    loss += 1.0 - factor.confidence;
                }
                count += 1;
            }
        }

        if count > 0 {
            loss / count as f64
        } else {
            0.0
        }
    }

    /// Backward pass: compute gradients
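    ///
    /// Uses a simple contrastive rule per expected factor: the expected basis vector
    /// is pulled toward the current estimate, while any other basis vector whose
    /// cosine similarity to the estimate exceeds 0.5 is pushed away.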
    fn backward(
        &mut self,
        _composite: &SparseVec,
        result: &FactorizationResult,
        expected: &HashMap<String, u32>,
    ) {
        // For each factor, compute gradient to move codebook vectors
        // toward better reconstruction
        for (name, expected_id) in expected {
            if let (Some(factor), Some(codebook)) =
                (result.factors.get(name), self.codebooks.get(name))
            {
                if let Some(grad_state) = self.gradients.get_mut(name) {
                    // Gradient: push expected closer to estimate, push others away
                    for basis in &codebook.basis_vectors {
                        let grad = grad_state
                            .gradients
                            .entry(basis.id)
                            .or_insert_with(|| vec![0.0; DIM]);

                        // Compute gradient contribution
                        let sim = factor.estimate.cosine(&basis.vector);

                        if basis.id == *expected_id {
                            // Positive gradient: move toward estimate
                            add_gradient_toward(grad, &factor.estimate, &basis.vector);
                        } else if sim > 0.5 {
                            // Negative gradient: move away if too similar
                            add_gradient_away(grad, &factor.estimate);
                        }
                    }
                }
            }
        }
    }

    /// Apply accumulated gradients to codebook vectors
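    ///
    /// Each tracked dimension is updated with momentum and weight decay
    /// (m ← μ·m + g, v ← v + lr·m − λ·v) and then re-thresholded at ±0.3 back
    /// into the sparse ternary {-1, 0, +1} representation.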
    fn apply_gradients(&mut self, batch_size: usize) {
        let lr = self.config.learning_rate / batch_size as f64;
        let momentum = self.config.momentum;
        let weight_decay = self.config.weight_decay;

        for (name, grad_state) in &mut self.gradients {
            if let Some(codebook) = self.codebooks.get_mut(name) {
                for basis in &mut codebook.basis_vectors {
                    if let Some(grad) = grad_state.gradients.get(&basis.id) {
                        // Get or create momentum buffer
                        let mom = grad_state
                            .momentum
                            .entry(basis.id)
                            .or_insert_with(|| vec![0.0; DIM]);

                        // Update each dimension
                        let mut new_pos = Vec::new();
                        let mut new_neg = Vec::new();

                        for dim in 0..DIM {
                            // Compute momentum-updated gradient
                            mom[dim] = momentum * mom[dim] + grad[dim];

                            // Get current value for this dimension
                            let is_pos = basis.vector.pos.contains(&dim);
                            let is_neg = basis.vector.neg.contains(&dim);
                            let current_val = if is_pos {
                                1.0
                            } else if is_neg {
                                -1.0
                            } else {
                                0.0
                            };

                            // Apply gradient with weight decay
                            let new_val = current_val + lr * mom[dim] - weight_decay * current_val;

                            // Threshold to sparse ternary
                            if new_val > 0.3 {
                                new_pos.push(dim);
                            } else if new_val < -0.3 {
                                new_neg.push(dim);
                            }
                        }

                        basis.vector.pos = new_pos;
                        basis.vector.neg = new_neg;
                    }
                }
            }
        }
    }

    /// Infer semantic variables from an encoded vector
    ///
    /// This performs factorization and returns a semantic interpretation
    /// of the encoded content.
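    ///
    /// # Example
    ///
    /// A sketch (hypothetical setup: `resonator` has labelled codebooks and
    /// `encoded` is a composite `SparseVec`):
    ///
    /// ```rust,ignore
    /// let inference = resonator.infer_semantics(&encoded);
    /// for (factor, label) in &inference.variables {
    ///     println!("{factor} = {label} ({:.2})", inference.confidences[factor]);
    /// }
    /// ```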
    pub fn infer_semantics(&mut self, vec: &SparseVec) -> SemanticInference {
        let result = self.factorize(vec);

        let mut inferred_variables = HashMap::new();
        let mut confidence_scores = HashMap::new();

        for (name, factor) in &result.factors {
            // Get the semantic label from the codebook
            if let Some(codebook) = self.codebooks.get(name) {
                if let Some(basis) = codebook
                    .basis_vectors
                    .iter()
                    .find(|b| b.id == factor.best_match_id)
                {
                    let label = basis
                        .label
                        .clone()
                        .unwrap_or_else(|| format!("id_{}", basis.id));
                    inferred_variables.insert(name.clone(), label);
                    confidence_scores.insert(name.clone(), factor.confidence);
                }
            }
        }

        SemanticInference {
            variables: inferred_variables,
            confidences: confidence_scores,
            raw_factors: result.factors,
            reconstruction_quality: result.reconstruction_quality,
        }
    }
}

/// A training example for the resonator
#[derive(Clone, Debug)]
pub struct TrainingExample {
    /// The composite vector (binding of factors)
    pub composite: SparseVec,
    /// Expected factor IDs by codebook name
    pub expected_factors: HashMap<String, u32>,
}

impl TrainingExample {
    /// Create a new training example
    pub fn new(composite: SparseVec, expected_factors: HashMap<String, u32>) -> Self {
        Self {
            composite,
            expected_factors,
        }
    }

    /// Create training example by binding vectors from codebooks
    pub fn from_codebooks(
        codebooks: &HashMap<String, &Codebook>,
        factor_ids: &HashMap<String, u32>,
    ) -> Option<Self> {
        let mut composite: Option<SparseVec> = None;

        for (name, id) in factor_ids {
            if let Some(codebook) = codebooks.get(name) {
                if let Some(basis) = codebook.basis_vectors.iter().find(|b| b.id == *id) {
                    composite = Some(match composite {
                        None => basis.vector.clone(),
                        Some(c) => c.bind(&basis.vector),
                    });
                }
            }
        }

        composite.map(|c| Self::new(c, factor_ids.clone()))
    }
}

/// Result of training
#[derive(Clone, Debug)]
pub struct TrainingResult {
    /// Final loss value
    pub final_loss: f64,
    /// Loss over epochs
    pub loss_history: Vec<f64>,
    /// Epochs completed
    pub epochs_completed: usize,
}

/// Result of semantic inference
#[derive(Clone, Debug)]
pub struct SemanticInference {
    /// Inferred semantic variables by name
    pub variables: HashMap<String, String>,
    /// Confidence scores for each variable
    pub confidences: HashMap<String, f64>,
    /// Raw factor results
    pub raw_factors: HashMap<String, RecoveredFactor>,
    /// Overall reconstruction quality
    pub reconstruction_quality: f64,
}

/// Perform weighted bundle of vectors
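///
/// Each input votes `+weight` on its positive dimensions and `-weight` on its
/// negative dimensions; dimensions whose accumulated vote exceeds ±0.3 survive
/// in the sparse ternary output.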
fn weighted_bundle(weighted_vecs: &[(SparseVec, f64)]) -> SparseVec {
    if weighted_vecs.is_empty() {
        return SparseVec::random();
    }

    // Accumulate weighted votes for each dimension
    let mut dim_votes: Vec<f64> = vec![0.0; DIM];

    for (vec, weight) in weighted_vecs {
        for &pos in &vec.pos {
            if pos < DIM {
                dim_votes[pos] += weight;
            }
        }
        for &neg in &vec.neg {
            if neg < DIM {
                dim_votes[neg] -= weight;
            }
        }
    }

    // Threshold to get sparse ternary result
    let threshold = 0.3;
    let mut pos = Vec::new();
    let mut neg = Vec::new();

    for (dim, &vote) in dim_votes.iter().enumerate() {
        if vote > threshold {
            pos.push(dim);
        } else if vote < -threshold {
            neg.push(dim);
        }
    }

    SparseVec { pos, neg }
}

/// Add gradient to move vector toward target (standalone function)
fn add_gradient_toward(grad: &mut [f64], target: &SparseVec, current: &SparseVec) {
    // Simple gradient: increase overlap with target
    for &pos in &target.pos {
        if pos < DIM {
            grad[pos] += 1.0;
        }
    }
    for &neg in &target.neg {
        if neg < DIM {
            grad[neg] -= 1.0;
        }
    }

    // Decrease current-specific dimensions not in target
    for &pos in &current.pos {
        if !target.pos.contains(&pos) && pos < DIM {
            grad[pos] -= 0.5;
        }
    }
    for &neg in &current.neg {
        if !target.neg.contains(&neg) && neg < DIM {
            grad[neg] += 0.5;
        }
    }
}

/// Add gradient to move vector away from target (standalone function)
fn add_gradient_away(grad: &mut [f64], target: &SparseVec) {
    // Push away from target
    for &pos in &target.pos {
        if pos < DIM {
            grad[pos] -= 0.5;
        }
    }
    for &neg in &target.neg {
        if neg < DIM {
            grad[neg] += 0.5;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::codebook::BasisVector;

    #[test]
    fn test_resonator_config_default() {
        let config = ResonatorConfig::default();
        assert_eq!(config.max_iterations, 50);
        assert!((config.convergence_threshold - 0.99).abs() < 0.001);
    }

    #[test]
    fn test_resonator_new() {
        let config = ResonatorConfig::default();
        let resonator = Resonator::new(config);
        assert!(resonator.codebooks.is_empty());
        assert!(resonator.factor_order.is_empty());
    }

    #[test]
    fn test_resonator_add_codebook() {
        let mut resonator = Resonator::new(ResonatorConfig::default());
        let mut codebook = Codebook::new(DIM);
        codebook.basis_vectors.push(BasisVector {
            id: 0,
            vector: SparseVec::random(),
            label: Some("test".to_string()),
            weight: 1.0,
        });

        resonator.add_codebook("type", codebook);

        assert!(resonator.get_codebook("type").is_some());
        assert_eq!(resonator.factor_order, vec!["type"]);
    }

    #[test]
    fn test_factorization_empty_resonator() {
        let mut resonator = Resonator::new(ResonatorConfig::default());
        let vec = SparseVec::random();
        let result = resonator.factorize(&vec);

        assert!(result.converged);
        assert_eq!(result.iterations, 0);
        assert!(result.factors.is_empty());
    }

    #[test]
    fn test_factorization_single_factor() {
        let mut resonator = Resonator::new(ResonatorConfig::default());

        // Create a codebook with a few basis vectors
        let mut codebook = Codebook::new(DIM);
        let target_vec = SparseVec::random();
        codebook.basis_vectors.push(BasisVector {
            id: 1,
            vector: target_vec.clone(),
            label: Some("target".to_string()),
            weight: 1.0,
        });
        codebook.basis_vectors.push(BasisVector {
            id: 2,
            vector: SparseVec::random(),
            label: Some("distractor".to_string()),
            weight: 1.0,
        });

        resonator.add_codebook("type", codebook);

        // The "composite" is just the target itself (no binding)
        let result = resonator.factorize(&target_vec);

        assert!(result.factors.contains_key("type"));
        let factor = result.factors.get("type").unwrap();

        // Should match the target with high confidence
        assert_eq!(factor.best_match_id, 1);
        assert!(factor.confidence > 0.5);
    }

    #[test]
    fn test_training_example_creation() {
        let vec = SparseVec::random();
        let mut factors = HashMap::new();
        factors.insert("type".to_string(), 1u32);

        let example = TrainingExample::new(vec.clone(), factors);
        assert_eq!(example.composite.pos, vec.pos);
        assert_eq!(example.expected_factors.get("type"), Some(&1u32));
    }

    #[test]
    fn test_weighted_bundle() {
        let v1 = SparseVec {
            pos: vec![1, 2, 3],
            neg: vec![4, 5],
        };
        let v2 = SparseVec {
            pos: vec![1, 6],
            neg: vec![4, 7],
        };

        let result = weighted_bundle(&[(v1, 0.6), (v2, 0.4)]);

        // Dimension 1 should be positive (0.6 + 0.4 = 1.0 > 0.3)
        assert!(result.pos.contains(&1));
        // Dimension 4 should be negative (-0.6 - 0.4 = -1.0 < -0.3)
        assert!(result.neg.contains(&4));
    }

    #[test]
    fn test_semantic_inference() {
        let mut resonator = Resonator::new(ResonatorConfig::default());

        let mut codebook = Codebook::new(DIM);
        let vec = SparseVec::random();
        codebook.basis_vectors.push(BasisVector {
            id: 42,
            vector: vec.clone(),
            label: Some("semantic_label".to_string()),
            weight: 1.0,
        });

        resonator.add_codebook("content", codebook);

        let inference = resonator.infer_semantics(&vec);

        assert!(inference.variables.contains_key("content"));
        assert!(inference.confidences.contains_key("content"));
    }
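
    // An additional sketch of `TrainingExample::from_codebooks`: with a single factor,
    // the composite is simply that factor's basis vector. The IDs and names here are
    // arbitrary test values.
    #[test]
    fn test_training_example_from_codebooks() {
        let mut codebook = Codebook::new(DIM);
        let vec = SparseVec::random();
        codebook.basis_vectors.push(BasisVector {
            id: 7,
            vector: vec.clone(),
            label: None,
            weight: 1.0,
        });

        let mut codebooks: HashMap<String, &Codebook> = HashMap::new();
        codebooks.insert("type".to_string(), &codebook);

        let mut factor_ids = HashMap::new();
        factor_ids.insert("type".to_string(), 7u32);

        let example = TrainingExample::from_codebooks(&codebooks, &factor_ids)
            .expect("factor id 7 is present in the codebook");
        assert_eq!(example.composite.pos, vec.pos);
        assert_eq!(example.expected_factors.get("type"), Some(&7u32));
    }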
}