use std::collections::HashMap;

use crate::codebook::Codebook;
use crate::vsa::{SparseVec, DIM};

/// Configuration for the resonator network.
#[derive(Clone, Debug)]
pub struct ResonatorConfig {
    /// Maximum number of iterations per factorization.
    pub max_iterations: usize,
    /// Minimum iteration-to-iteration similarity required to declare convergence.
    pub convergence_threshold: f64,
    /// Learning rate for codebook training.
    pub learning_rate: f64,
    /// Momentum coefficient for gradient updates.
    pub momentum: f64,
    /// Weight decay applied to codebook vectors during training.
    pub weight_decay: f64,
    /// Softmax temperature used by soft cleanup.
    pub temperature: f64,
    /// Use soft (weighted) cleanup instead of hard nearest-neighbour cleanup.
    pub soft_cleanup: bool,
    /// Number of top candidates blended during soft cleanup.
    pub soft_cleanup_top_k: usize,
}

impl Default for ResonatorConfig {
    fn default() -> Self {
        Self {
            max_iterations: 50,
            convergence_threshold: 0.99,
            learning_rate: 0.01,
            momentum: 0.9,
            weight_decay: 1e-5,
            temperature: 0.1,
            soft_cleanup: true,
            soft_cleanup_top_k: 8,
        }
    }
}

/// Outcome of a single factorization run.
#[derive(Clone, Debug)]
pub struct FactorizationResult {
    /// Recovered factor per registered codebook name.
    pub factors: HashMap<String, RecoveredFactor>,
    /// Number of iterations actually executed.
    pub iterations: usize,
    /// Whether the convergence threshold was reached.
    pub converged: bool,
    /// Cosine similarity between the re-bound factors and the input composite.
    pub reconstruction_quality: f64,
    /// Minimum estimate similarity recorded at each iteration.
    pub convergence_history: Vec<f64>,
}

/// A single recovered factor and its codebook candidates.
#[derive(Clone, Debug)]
pub struct RecoveredFactor {
    /// Id of the best-matching codebook entry.
    pub best_match_id: u32,
    /// Vector of the best-matching codebook entry, if any.
    pub best_match: Option<SparseVec>,
    /// Cosine similarity between the estimate and the best match.
    pub confidence: f64,
    /// Raw estimate produced by the resonator loop.
    pub estimate: SparseVec,
    /// Top candidate (id, similarity) pairs, best first.
    pub candidates: Vec<(u32, f64)>,
}

/// Per-codebook gradient and momentum buffers used during training.
#[derive(Clone, Debug)]
struct CodebookGradient {
    gradients: HashMap<u32, Vec<f64>>,
    momentum: HashMap<u32, Vec<f64>>,
}

impl CodebookGradient {
    fn new() -> Self {
        Self {
            gradients: HashMap::new(),
            momentum: HashMap::new(),
        }
    }

    fn zero_gradients(&mut self) {
        for grad in self.gradients.values_mut() {
            grad.fill(0.0);
        }
    }
}

/// Resonator network that factorizes composite vectors against named codebooks.
pub struct Resonator {
    config: ResonatorConfig,
    codebooks: HashMap<String, Codebook>,
    factor_order: Vec<String>,
    gradients: HashMap<String, CodebookGradient>,
    stats: ResonatorStats,
}

/// Running statistics collected across factorizations and training.
#[derive(Clone, Debug, Default)]
pub struct ResonatorStats {
    pub total_factorizations: u64,
    pub converged_count: u64,
    pub avg_iterations: f64,
    pub avg_reconstruction_quality: f64,
    pub training_steps: u64,
    pub current_loss: f64,
}

impl Resonator {
    pub fn new(config: ResonatorConfig) -> Self {
        Self {
            config,
            codebooks: HashMap::new(),
            factor_order: Vec::new(),
            gradients: HashMap::new(),
            stats: ResonatorStats::default(),
        }
    }

    /// Registers a codebook under `name`; factors are resolved in insertion order.
    pub fn add_codebook(&mut self, name: &str, codebook: Codebook) {
        self.codebooks.insert(name.to_string(), codebook);
        self.factor_order.push(name.to_string());
        self.gradients
            .insert(name.to_string(), CodebookGradient::new());
    }

    pub fn get_codebook(&self, name: &str) -> Option<&Codebook> {
        self.codebooks.get(name)
    }

    pub fn get_codebook_mut(&mut self, name: &str) -> Option<&mut Codebook> {
        self.codebooks.get_mut(name)
    }

    pub fn stats(&self) -> &ResonatorStats {
        &self.stats
    }

    /// Factorizes `composite` using the configured maximum number of iterations.
    pub fn factorize(&mut self, composite: &SparseVec) -> FactorizationResult {
        self.factorize_with_iterations(composite, self.config.max_iterations)
    }

    /// Runs the alternating estimate/cleanup loop for up to `max_iterations`.
    pub fn factorize_with_iterations(
        &mut self,
        composite: &SparseVec,
        max_iterations: usize,
    ) -> FactorizationResult {
        if self.factor_order.is_empty() {
            return FactorizationResult {
                factors: HashMap::new(),
                iterations: 0,
                converged: true,
                reconstruction_quality: 0.0,
                convergence_history: Vec::new(),
            };
        }

        // Start from a random estimate for every factor.
        let mut estimates: HashMap<String, SparseVec> = self
            .factor_order
            .iter()
            .map(|name| (name.clone(), SparseVec::random()))
            .collect();

        let mut convergence_history = Vec::new();
        let mut prev_estimates = estimates.clone();
        let mut converged = false;

        let factor_order = self.factor_order.clone();

        for iteration in 0..max_iterations {
            for name in &factor_order {
                // Unbind the other factors' current estimates from the composite.
                let mut unbound = composite.clone();
                for (other_name, other_estimate) in &estimates {
                    if other_name != name {
                        unbound = unbound.bind(other_estimate);
                    }
                }

                // Snap the noisy estimate back onto this factor's codebook.
                let cleaned = if let Some(codebook) = self.codebooks.get(name) {
                    self.cleanup(&unbound, codebook)
                } else {
                    unbound
                };

                estimates.insert(name.clone(), cleaned);
            }

            let mut min_similarity = 1.0f64;
            for name in &self.factor_order {
                let curr = estimates.get(name).unwrap();
                let prev = prev_estimates.get(name).unwrap();
                let sim = curr.cosine(prev);
                min_similarity = min_similarity.min(sim);
            }
            convergence_history.push(min_similarity);

            if min_similarity >= self.config.convergence_threshold {
                converged = true;
                self.stats.converged_count += 1;
            }

            prev_estimates = estimates.clone();

            if converged {
                let reconstruction = self.reconstruct(&estimates);
                let quality = reconstruction.cosine(composite);

                self.stats.total_factorizations += 1;
                self.stats.avg_reconstruction_quality = (self.stats.avg_reconstruction_quality
                    * (self.stats.total_factorizations - 1) as f64
                    + quality)
                    / self.stats.total_factorizations as f64;
                self.stats.avg_iterations = (self.stats.avg_iterations
                    * (self.stats.total_factorizations - 1) as f64
                    + (iteration + 1) as f64)
                    / self.stats.total_factorizations as f64;

                return self.build_result(
                    estimates,
                    iteration + 1,
                    true,
                    quality,
                    convergence_history,
                );
            }
        }

        // Did not converge within the iteration budget; still fold this run into the
        // running averages so they stay consistent with `total_factorizations`.
        self.stats.total_factorizations += 1;
        let reconstruction = self.reconstruct(&estimates);
        let quality = reconstruction.cosine(composite);
        self.stats.avg_reconstruction_quality = (self.stats.avg_reconstruction_quality
            * (self.stats.total_factorizations - 1) as f64
            + quality)
            / self.stats.total_factorizations as f64;
        self.stats.avg_iterations = (self.stats.avg_iterations
            * (self.stats.total_factorizations - 1) as f64
            + max_iterations as f64)
            / self.stats.total_factorizations as f64;

        self.build_result(
            estimates,
            max_iterations,
            false,
            quality,
            convergence_history,
        )
    }

    /// Projects `vec` back onto `codebook`, either softly or via nearest neighbour.
    fn cleanup(&self, vec: &SparseVec, codebook: &Codebook) -> SparseVec {
        if codebook.basis_vectors.is_empty() {
            return vec.clone();
        }

        if self.config.soft_cleanup {
            self.soft_cleanup(vec, codebook)
        } else {
            self.hard_cleanup(vec, codebook)
        }
    }

    /// Hard cleanup: return the single closest codebook entry.
    fn hard_cleanup(&self, vec: &SparseVec, codebook: &Codebook) -> SparseVec {
        let mut best_sim = f64::NEG_INFINITY;
        let mut best_vec = vec.clone();

        for basis in &codebook.basis_vectors {
            let sim = vec.cosine(&basis.vector);
            if sim > best_sim {
                best_sim = sim;
                best_vec = basis.vector.clone();
            }
        }

        best_vec
    }

    /// Soft cleanup: softmax-weighted bundle of the top-k closest codebook entries.
    fn soft_cleanup(&self, vec: &SparseVec, codebook: &Codebook) -> SparseVec {
        let mut similarities: Vec<(usize, f64)> = codebook
            .basis_vectors
            .iter()
            .enumerate()
            .map(|(i, basis)| (i, vec.cosine(&basis.vector)))
            .collect();

        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let top_k: Vec<_> = similarities
            .into_iter()
            .take(self.config.soft_cleanup_top_k)
            .collect();

        if top_k.is_empty() {
            return vec.clone();
        }

        // Softmax over similarities, shifted by the maximum for numerical stability.
        let max_sim = top_k.first().map(|t| t.1).unwrap_or(0.0);
        let weights: Vec<f64> = top_k
            .iter()
            .map(|(_, sim)| ((sim - max_sim) / self.config.temperature).exp())
            .collect();
        let weight_sum: f64 = weights.iter().sum();

        if weight_sum == 0.0 {
            return codebook.basis_vectors[top_k[0].0].vector.clone();
        }

        let weighted_vecs: Vec<_> = top_k
            .iter()
            .zip(weights.iter())
            .map(|((idx, _), w)| (codebook.basis_vectors[*idx].vector.clone(), *w / weight_sum))
            .collect();

        weighted_bundle(&weighted_vecs)
    }

    /// Re-binds the current factor estimates into a single composite vector.
    fn reconstruct(&self, factors: &HashMap<String, SparseVec>) -> SparseVec {
        let mut result = SparseVec::random();
        let mut first = true;

        for name in &self.factor_order {
            if let Some(factor) = factors.get(name) {
                if first {
                    result = factor.clone();
                    first = false;
                } else {
                    result = result.bind(factor);
                }
            }
        }

        result
    }

    /// Ranks codebook candidates for each estimate and packages the result.
    fn build_result(
        &self,
        estimates: HashMap<String, SparseVec>,
        iterations: usize,
        converged: bool,
        quality: f64,
        convergence_history: Vec<f64>,
    ) -> FactorizationResult {
        let mut factors = HashMap::new();

        for (name, estimate) in estimates {
            if let Some(codebook) = self.codebooks.get(&name) {
                let mut candidates: Vec<(u32, f64)> = codebook
                    .basis_vectors
                    .iter()
                    .map(|b| (b.id, estimate.cosine(&b.vector)))
                    .collect();
                candidates
                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

                let best = candidates.first().cloned().unwrap_or((0, 0.0));
                let best_match = codebook
                    .basis_vectors
                    .iter()
                    .find(|b| b.id == best.0)
                    .map(|b| b.vector.clone());

                factors.insert(
                    name,
                    RecoveredFactor {
                        best_match_id: best.0,
                        best_match,
                        confidence: best.1,
                        estimate,
                        candidates: candidates.into_iter().take(10).collect(),
                    },
                );
            } else {
                factors.insert(
                    name,
                    RecoveredFactor {
                        best_match_id: 0,
                        best_match: None,
                        confidence: 0.0,
                        estimate,
                        candidates: Vec::new(),
                    },
                );
            }
        }

        FactorizationResult {
            factors,
            iterations,
            converged,
            reconstruction_quality: quality,
            convergence_history,
        }
    }

    /// Trains codebook vectors from labeled composites using a simple
    /// gradient-with-momentum update.
    pub fn train(
        &mut self,
        training_data: &[TrainingExample],
        epochs: usize,
    ) -> Result<TrainingResult, String> {
        if training_data.is_empty() {
            return Err("No training data provided".to_string());
        }

        let mut loss_history = Vec::new();
        let batch_size = 32.min(training_data.len());

        for _epoch in 0..epochs {
            let mut epoch_loss = 0.0;

            for batch in training_data.chunks(batch_size) {
                self.zero_gradients();

                let mut batch_loss = 0.0;
                for example in batch {
                    let result = self.factorize(&example.composite);

                    let recon_loss = 1.0 - result.reconstruction_quality;
                    let factor_loss = self.compute_factor_loss(&result, &example.expected_factors);
                    let total_loss = recon_loss + factor_loss;
                    batch_loss += total_loss;

                    self.backward(&example.composite, &result, &example.expected_factors);
                }

                self.apply_gradients(batch.len());

                epoch_loss += batch_loss;
            }

            // Average over the number of examples (the last chunk may be smaller
            // than `batch_size`).
            let avg_loss = epoch_loss / training_data.len() as f64;
            loss_history.push(avg_loss);
            self.stats.current_loss = avg_loss;
            self.stats.training_steps += 1;
        }

        Ok(TrainingResult {
            final_loss: *loss_history.last().unwrap_or(&0.0),
            loss_history,
            epochs_completed: epochs,
        })
    }

    fn zero_gradients(&mut self) {
        for grad in self.gradients.values_mut() {
            grad.zero_gradients();
        }
    }

    /// Average penalty over expected factors that were not recovered correctly.
    fn compute_factor_loss(
        &self,
        result: &FactorizationResult,
        expected: &HashMap<String, u32>,
    ) -> f64 {
        let mut loss = 0.0;
        let mut count = 0;

        for (name, expected_id) in expected {
            if let Some(factor) = result.factors.get(name) {
                if factor.best_match_id != *expected_id {
                    loss += 1.0 - factor.confidence;
                }
                count += 1;
            }
        }

        if count > 0 {
            loss / count as f64
        } else {
            0.0
        }
    }

    /// Accumulates gradients that pull each expected codebook entry toward the
    /// recovered estimate and push confusable entries away from it.
    fn backward(
        &mut self,
        _composite: &SparseVec,
        result: &FactorizationResult,
        expected: &HashMap<String, u32>,
    ) {
        for (name, expected_id) in expected {
            if let (Some(factor), Some(codebook)) =
                (result.factors.get(name), self.codebooks.get(name))
            {
                if let Some(grad_state) = self.gradients.get_mut(name) {
                    for basis in &codebook.basis_vectors {
                        let grad = grad_state
                            .gradients
                            .entry(basis.id)
                            .or_insert_with(|| vec![0.0; DIM]);

                        let sim = factor.estimate.cosine(&basis.vector);

                        if basis.id == *expected_id {
                            add_gradient_toward(grad, &factor.estimate, &basis.vector);
                        } else if sim > 0.5 {
                            add_gradient_away(grad, &factor.estimate);
                        }
                    }
                }
            }
        }
    }

    /// Applies accumulated gradients to the codebook vectors, then re-ternarizes
    /// each dimension with a fixed threshold.
    fn apply_gradients(&mut self, batch_size: usize) {
        let lr = self.config.learning_rate / batch_size as f64;
        let momentum = self.config.momentum;
        let weight_decay = self.config.weight_decay;

        for (name, grad_state) in &mut self.gradients {
            if let Some(codebook) = self.codebooks.get_mut(name) {
                for basis in &mut codebook.basis_vectors {
                    if let Some(grad) = grad_state.gradients.get(&basis.id) {
                        let mom = grad_state
                            .momentum
                            .entry(basis.id)
                            .or_insert_with(|| vec![0.0; DIM]);

                        let mut new_pos = Vec::new();
                        let mut new_neg = Vec::new();

                        for dim in 0..DIM {
                            mom[dim] = momentum * mom[dim] + grad[dim];

                            // Current ternary value of this dimension.
                            let is_pos = basis.vector.pos.contains(&dim);
                            let is_neg = basis.vector.neg.contains(&dim);
                            let current_val = if is_pos {
                                1.0
                            } else if is_neg {
                                -1.0
                            } else {
                                0.0
                            };

                            let new_val = current_val + lr * mom[dim] - weight_decay * current_val;

                            // Re-threshold back to a sparse ternary representation.
                            if new_val > 0.3 {
                                new_pos.push(dim);
                            } else if new_val < -0.3 {
                                new_neg.push(dim);
                            }
                        }

                        basis.vector.pos = new_pos;
                        basis.vector.neg = new_neg;
                    }
                }
            }
        }
    }

    /// Factorizes `vec` and maps each recovered factor id back to its codebook label.
    pub fn infer_semantics(&mut self, vec: &SparseVec) -> SemanticInference {
        let result = self.factorize(vec);

        let mut inferred_variables = HashMap::new();
        let mut confidence_scores = HashMap::new();

        for (name, factor) in &result.factors {
            if let Some(codebook) = self.codebooks.get(name) {
                if let Some(basis) = codebook
                    .basis_vectors
                    .iter()
                    .find(|b| b.id == factor.best_match_id)
                {
                    let label = basis
                        .label
                        .clone()
                        .unwrap_or_else(|| format!("id_{}", basis.id));
                    inferred_variables.insert(name.clone(), label);
                    confidence_scores.insert(name.clone(), factor.confidence);
                }
            }
        }

        SemanticInference {
            variables: inferred_variables,
            confidences: confidence_scores,
            raw_factors: result.factors,
            reconstruction_quality: result.reconstruction_quality,
        }
    }
}

/// A labeled training composite.
#[derive(Clone, Debug)]
pub struct TrainingExample {
    /// The composite vector to factorize.
    pub composite: SparseVec,
    /// Expected codebook entry id per factor name.
    pub expected_factors: HashMap<String, u32>,
}

impl TrainingExample {
    pub fn new(composite: SparseVec, expected_factors: HashMap<String, u32>) -> Self {
        Self {
            composite,
            expected_factors,
        }
    }

    /// Builds an example by binding together the referenced codebook entries.
    /// Returns `None` if no referenced entry could be found.
    pub fn from_codebooks(
        codebooks: &HashMap<String, &Codebook>,
        factor_ids: &HashMap<String, u32>,
    ) -> Option<Self> {
        let mut composite: Option<SparseVec> = None;

        for (name, id) in factor_ids {
            if let Some(codebook) = codebooks.get(name) {
                if let Some(basis) = codebook.basis_vectors.iter().find(|b| b.id == *id) {
                    composite = Some(match composite {
                        None => basis.vector.clone(),
                        Some(c) => c.bind(&basis.vector),
                    });
                }
            }
        }

        composite.map(|c| Self::new(c, factor_ids.clone()))
    }
}

/// Summary returned by `train`.
#[derive(Clone, Debug)]
pub struct TrainingResult {
    pub final_loss: f64,
    pub loss_history: Vec<f64>,
    pub epochs_completed: usize,
}

/// Human-readable interpretation of a factorized vector.
#[derive(Clone, Debug)]
pub struct SemanticInference {
    /// Best-matching codebook label per factor name.
    pub variables: HashMap<String, String>,
    /// Confidence (cosine similarity) per factor name.
    pub confidences: HashMap<String, f64>,
    /// Full recovered factors, for callers that need more detail.
    pub raw_factors: HashMap<String, RecoveredFactor>,
    /// Quality of the reconstruction from the recovered factors.
    pub reconstruction_quality: f64,
}

/// Bundles weighted sparse vectors by a per-dimension weighted vote.
fn weighted_bundle(weighted_vecs: &[(SparseVec, f64)]) -> SparseVec {
    if weighted_vecs.is_empty() {
        return SparseVec::random();
    }

    let mut dim_votes: Vec<f64> = vec![0.0; DIM];

    for (vec, weight) in weighted_vecs {
        for &pos in &vec.pos {
            if pos < DIM {
                dim_votes[pos] += weight;
            }
        }
        for &neg in &vec.neg {
            if neg < DIM {
                dim_votes[neg] -= weight;
            }
        }
    }

    let threshold = 0.3;
    let mut pos = Vec::new();
    let mut neg = Vec::new();

    for (dim, &vote) in dim_votes.iter().enumerate() {
        if vote > threshold {
            pos.push(dim);
        } else if vote < -threshold {
            neg.push(dim);
        }
    }

    SparseVec { pos, neg }
}

/// Accumulates a gradient that moves `current` toward `target`.
fn add_gradient_toward(grad: &mut [f64], target: &SparseVec, current: &SparseVec) {
    for &pos in &target.pos {
        if pos < DIM {
            grad[pos] += 1.0;
        }
    }
    for &neg in &target.neg {
        if neg < DIM {
            grad[neg] -= 1.0;
        }
    }

    // Push down dimensions that are active in `current` but not in `target`.
    for &pos in &current.pos {
        if !target.pos.contains(&pos) && pos < DIM {
            grad[pos] -= 0.5;
        }
    }
    for &neg in &current.neg {
        if !target.neg.contains(&neg) && neg < DIM {
            grad[neg] += 0.5;
        }
    }
}

/// Accumulates a gradient that pushes the updated vector away from `target`.
fn add_gradient_away(grad: &mut [f64], target: &SparseVec) {
    for &pos in &target.pos {
        if pos < DIM {
            grad[pos] -= 0.5;
        }
    }
    for &neg in &target.neg {
        if neg < DIM {
            grad[neg] += 0.5;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::codebook::BasisVector;

    #[test]
    fn test_resonator_config_default() {
        let config = ResonatorConfig::default();
        assert_eq!(config.max_iterations, 50);
        assert!((config.convergence_threshold - 0.99).abs() < 0.001);
    }

    #[test]
    fn test_resonator_new() {
        let config = ResonatorConfig::default();
        let resonator = Resonator::new(config);
        assert!(resonator.codebooks.is_empty());
        assert!(resonator.factor_order.is_empty());
    }

    #[test]
    fn test_resonator_add_codebook() {
        let mut resonator = Resonator::new(ResonatorConfig::default());
        let mut codebook = Codebook::new(DIM);
        codebook.basis_vectors.push(BasisVector {
            id: 0,
            vector: SparseVec::random(),
            label: Some("test".to_string()),
            weight: 1.0,
        });

        resonator.add_codebook("type", codebook);

        assert!(resonator.get_codebook("type").is_some());
        assert_eq!(resonator.factor_order, vec!["type"]);
    }

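    // Added illustrative test (a minimal sketch using only the APIs defined in this
    // module): with `soft_cleanup` disabled, the hard nearest-neighbour path should
    // still recover a single stored entry from its own vector. Ids and labels here
    // are arbitrary test data.
    #[test]
    fn test_factorization_hard_cleanup() {
        let config = ResonatorConfig {
            soft_cleanup: false,
            ..ResonatorConfig::default()
        };
        let mut resonator = Resonator::new(config);

        let target = SparseVec::random();
        let mut codebook = Codebook::new(DIM);
        codebook.basis_vectors.push(BasisVector {
            id: 9,
            vector: target.clone(),
            label: Some("only".to_string()),
            weight: 1.0,
        });
        resonator.add_codebook("type", codebook);

        let result = resonator.factorize(&target);
        let factor = result.factors.get("type").expect("factor should be present");
        assert_eq!(factor.best_match_id, 9);
    }
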
    #[test]
    fn test_factorization_empty_resonator() {
        let mut resonator = Resonator::new(ResonatorConfig::default());
        let vec = SparseVec::random();
        let result = resonator.factorize(&vec);

        assert!(result.converged);
        assert_eq!(result.iterations, 0);
        assert!(result.factors.is_empty());
    }

    #[test]
    fn test_factorization_single_factor() {
        let mut resonator = Resonator::new(ResonatorConfig::default());

        let mut codebook = Codebook::new(DIM);
        let target_vec = SparseVec::random();
        codebook.basis_vectors.push(BasisVector {
            id: 1,
            vector: target_vec.clone(),
            label: Some("target".to_string()),
            weight: 1.0,
        });
        codebook.basis_vectors.push(BasisVector {
            id: 2,
            vector: SparseVec::random(),
            label: Some("distractor".to_string()),
            weight: 1.0,
        });

        resonator.add_codebook("type", codebook);

        let result = resonator.factorize(&target_vec);

        assert!(result.factors.contains_key("type"));
        let factor = result.factors.get("type").unwrap();

        assert_eq!(factor.best_match_id, 1);
        assert!(factor.confidence > 0.5);
    }

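    // Added illustrative test (a structural sketch, not a recovery guarantee):
    // factorizing a two-factor composite should return a RecoveredFactor entry for
    // every registered codebook and one convergence-history entry per iteration.
    // The "role"/"filler" names and ids are arbitrary test data.
    #[test]
    fn test_factorization_two_factors_returns_all_keys() {
        let mut resonator = Resonator::new(ResonatorConfig::default());

        let role_vec = SparseVec::random();
        let filler_vec = SparseVec::random();

        let mut roles = Codebook::new(DIM);
        roles.basis_vectors.push(BasisVector {
            id: 1,
            vector: role_vec.clone(),
            label: Some("role_a".to_string()),
            weight: 1.0,
        });
        let mut fillers = Codebook::new(DIM);
        fillers.basis_vectors.push(BasisVector {
            id: 7,
            vector: filler_vec.clone(),
            label: Some("filler_x".to_string()),
            weight: 1.0,
        });

        resonator.add_codebook("role", roles);
        resonator.add_codebook("filler", fillers);

        let composite = role_vec.bind(&filler_vec);
        let result = resonator.factorize(&composite);

        assert!(result.factors.contains_key("role"));
        assert!(result.factors.contains_key("filler"));
        assert!(result.iterations >= 1);
        assert_eq!(result.convergence_history.len(), result.iterations);
    }
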
    #[test]
    fn test_training_example_creation() {
        let vec = SparseVec::random();
        let mut factors = HashMap::new();
        factors.insert("type".to_string(), 1u32);

        let example = TrainingExample::new(vec.clone(), factors);
        assert_eq!(example.composite.pos, vec.pos);
        assert_eq!(example.expected_factors.get("type"), Some(&1u32));
    }

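    // Added illustrative test (a sketch of the intended round trip): build a
    // composite from codebook entries via `from_codebooks` and check that the
    // expected ids are carried through. With a single factor, the composite is
    // just that entry's vector. Ids and labels are arbitrary test data.
    #[test]
    fn test_training_example_from_codebooks() {
        let vec = SparseVec::random();
        let mut codebook = Codebook::new(DIM);
        codebook.basis_vectors.push(BasisVector {
            id: 3,
            vector: vec.clone(),
            label: Some("entry".to_string()),
            weight: 1.0,
        });

        let mut codebooks: HashMap<String, &Codebook> = HashMap::new();
        codebooks.insert("type".to_string(), &codebook);

        let mut factor_ids = HashMap::new();
        factor_ids.insert("type".to_string(), 3u32);

        let example = TrainingExample::from_codebooks(&codebooks, &factor_ids)
            .expect("composite should be built from a known id");
        assert_eq!(example.expected_factors.get("type"), Some(&3u32));
        assert_eq!(example.composite.pos, vec.pos);
    }
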
    #[test]
    fn test_weighted_bundle() {
        let v1 = SparseVec {
            pos: vec![1, 2, 3],
            neg: vec![4, 5],
        };
        let v2 = SparseVec {
            pos: vec![1, 6],
            neg: vec![4, 7],
        };

        let result = weighted_bundle(&[(v1, 0.6), (v2, 0.4)]);

        assert!(result.pos.contains(&1));
        assert!(result.neg.contains(&4));
    }

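    // Added illustrative test (a minimal sketch of the voting rule): opposing
    // votes on the same dimension cancel below the 0.3 threshold in
    // `weighted_bundle`, so the dimension is dropped from both pos and neg.
    #[test]
    fn test_weighted_bundle_cancellation() {
        let v1 = SparseVec {
            pos: vec![1],
            neg: vec![],
        };
        let v2 = SparseVec {
            pos: vec![],
            neg: vec![1],
        };

        let result = weighted_bundle(&[(v1, 0.5), (v2, 0.5)]);

        assert!(!result.pos.contains(&1));
        assert!(!result.neg.contains(&1));
    }
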
    #[test]
    fn test_semantic_inference() {
        let mut resonator = Resonator::new(ResonatorConfig::default());

        let mut codebook = Codebook::new(DIM);
        let vec = SparseVec::random();
        codebook.basis_vectors.push(BasisVector {
            id: 42,
            vector: vec.clone(),
            label: Some("semantic_label".to_string()),
            weight: 1.0,
        });

        resonator.add_codebook("content", codebook);

        let inference = resonator.infer_semantics(&vec);

        assert!(inference.variables.contains_key("content"));
        assert!(inference.confidences.contains_key("content"));
    }
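
    // Added illustrative smoke test (a sketch, not a convergence claim): `train`
    // on a single labeled example should run the requested epochs and report a
    // finite loss with one history entry per epoch. Ids and labels are arbitrary
    // test data.
    #[test]
    fn test_train_smoke() {
        let mut resonator = Resonator::new(ResonatorConfig::default());

        let vec = SparseVec::random();
        let mut codebook = Codebook::new(DIM);
        codebook.basis_vectors.push(BasisVector {
            id: 1,
            vector: vec.clone(),
            label: Some("only".to_string()),
            weight: 1.0,
        });
        resonator.add_codebook("type", codebook);

        let mut expected = HashMap::new();
        expected.insert("type".to_string(), 1u32);
        let examples = vec![TrainingExample::new(vec, expected)];

        let result = resonator.train(&examples, 2).expect("training should succeed");
        assert_eq!(result.epochs_completed, 2);
        assert_eq!(result.loss_history.len(), 2);
        assert!(result.final_loss.is_finite());
    }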
}