use std::collections::VecDeque;

use rand::rngs::StdRng;
use serde::{Deserialize, Serialize};

use crate::error::PcError;
use crate::linalg::cpu::CpuLinAlg;
use crate::linalg::LinAlg;
use crate::mlp_critic::{MlpCritic, MlpCriticConfig};
use crate::pc_actor::{InferResult, PcActor, PcActorConfig, SelectionMode};

fn default_gamma() -> f64 {
    0.95
}

fn default_surprise_low() -> f64 {
    0.02
}

fn default_surprise_high() -> f64 {
    0.15
}

fn default_adaptive_surprise() -> bool {
    true
}

fn default_surprise_buffer_size() -> usize {
    400
}

fn default_entropy_coeff() -> f64 {
    0.01
}

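/// Configuration for a [`PcActorCritic`]: the actor and critic sub-configs
/// plus the discount factor `gamma`, the surprise-gating thresholds, and the
/// entropy regularization coefficient. Fields with `serde(default = ...)`
/// fall back to the module-level defaults above when missing from serialized
/// input.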
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PcActorCriticConfig {
    pub actor: PcActorConfig,
    pub critic: MlpCriticConfig,
    #[serde(default = "default_gamma")]
    pub gamma: f64,
    #[serde(default = "default_surprise_low")]
    pub surprise_low: f64,
    #[serde(default = "default_surprise_high")]
    pub surprise_high: f64,
    #[serde(default = "default_adaptive_surprise")]
    pub adaptive_surprise: bool,
    #[serde(default = "default_surprise_buffer_size")]
    pub surprise_buffer_size: usize,
    #[serde(default = "default_entropy_coeff")]
    pub entropy_coeff: f64,
}

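/// One recorded step of an episode: the observation, the actor's inference
/// outputs at that step, the action taken, the valid-action set, and the
/// reward. [`PcActorCritic::learn`] replays a slice of these to compute
/// discounted returns and apply updates.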
#[derive(Debug, Clone)]
pub struct TrajectoryStep<L: LinAlg = CpuLinAlg> {
    pub input: L::Vector,
    pub latent_concat: L::Vector,
    pub y_conv: L::Vector,
    pub hidden_states: Vec<L::Vector>,
    pub prediction_errors: Vec<L::Vector>,
    pub tanh_components: Vec<Option<L::Vector>>,
    pub action: usize,
    pub valid_actions: Vec<usize>,
    pub reward: f64,
    pub surprise_score: f64,
    pub steps_used: usize,
}

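/// Per-layer buffer of hidden-state snapshots (one inner `Vec` per layer),
/// recorded sample by sample and later flattened into activation matrices
/// for [`PcActorCritic::crossover`].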
#[derive(Debug, Clone)]
pub struct ActivationCache<L: LinAlg = CpuLinAlg> {
    layers: Vec<Vec<L::Vector>>,
}

impl<L: LinAlg> ActivationCache<L> {
    pub fn new(num_layers: usize) -> Self {
        Self {
            layers: (0..num_layers).map(|_| Vec::new()).collect(),
        }
    }

    pub fn batch_size(&self) -> usize {
        self.layers.first().map_or(0, |l| l.len())
    }

    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    pub fn record(&mut self, hidden_states: &[L::Vector]) {
        for (layer, state) in self.layers.iter_mut().zip(hidden_states.iter()) {
            layer.push(state.clone());
        }
    }

    pub fn layer(&self, layer_idx: usize) -> &[L::Vector] {
        &self.layers[layer_idx]
    }
}

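/// An actor-critic agent that pairs a predictive-coding actor with an MLP
/// critic. The critic scores `[input ++ latent_concat]` states; its value
/// estimates provide the advantage (or TD error) that weights the actor's
/// policy-gradient update, and `surprise_scale` gates the update magnitude.
///
/// A minimal episode loop, as a sketch (assumes an environment supplying
/// `obs: Vec<f64>`, `valid: Vec<usize>`, and `reward`, and a `config` like
/// the one built in the tests below):
///
/// ```ignore
/// let mut agent: PcActorCritic = PcActorCritic::new(config, 42)?;
/// let (action, infer) = agent.act(&obs, &valid, SelectionMode::Training);
/// let trajectory = vec![TrajectoryStep {
///     input: obs,
///     latent_concat: infer.latent_concat,
///     y_conv: infer.y_conv,
///     hidden_states: infer.hidden_states,
///     prediction_errors: infer.prediction_errors,
///     tanh_components: infer.tanh_components,
///     action,
///     valid_actions: valid,
///     reward,
///     surprise_score: infer.surprise_score,
///     steps_used: infer.steps_used,
/// }];
/// let mean_critic_loss = agent.learn(&trajectory);
/// ```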
#[derive(Debug)]
pub struct PcActorCritic<L: LinAlg = CpuLinAlg> {
    pub(crate) actor: PcActor<L>,
    pub(crate) critic: MlpCritic<L>,
    pub config: PcActorCriticConfig,
    rng: StdRng,
    surprise_buffer: VecDeque<f64>,
}

impl<L: LinAlg> PcActorCritic<L> {
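    /// Builds an agent from `config` and a seed, validating that `gamma` is
    /// in `[0.0, 1.0]` and that the surprise buffer is large enough when
    /// adaptive surprise is enabled.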
    pub fn new(config: PcActorCriticConfig, seed: u64) -> Result<Self, PcError> {
        if !(0.0..=1.0).contains(&config.gamma) {
            return Err(PcError::ConfigValidation(format!(
                "gamma must be in [0.0, 1.0], got {}",
                config.gamma
            )));
        }
        if config.adaptive_surprise && config.surprise_buffer_size < 10 {
            return Err(PcError::ConfigValidation(format!(
                "surprise_buffer_size must be >= 10 when adaptive_surprise is enabled, got {}",
                config.surprise_buffer_size
            )));
        }

        use rand::SeedableRng;
        let mut rng = StdRng::seed_from_u64(seed);
        let actor = PcActor::<L>::new(config.actor.clone(), &mut rng)?;
        let critic = MlpCritic::<L>::new(config.critic.clone(), &mut rng)?;
        Ok(Self {
            actor,
            critic,
            config,
            rng,
            surprise_buffer: VecDeque::new(),
        })
    }

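    /// Produces a child agent by crossing over two parents' actors and
    /// critics. The activation caches are flattened to per-layer activation
    /// matrices (see `cache_to_matrices`) so the underlying crossover can
    /// align neurons before blending with weight `alpha`; each pair of
    /// caches must have matching batch sizes.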
    #[allow(clippy::too_many_arguments)]
    pub fn crossover(
        parent_a: &PcActorCritic<L>,
        parent_b: &PcActorCritic<L>,
        actor_cache_a: &ActivationCache<L>,
        actor_cache_b: &ActivationCache<L>,
        critic_cache_a: &ActivationCache<L>,
        critic_cache_b: &ActivationCache<L>,
        alpha: f64,
        child_config: PcActorCriticConfig,
        seed: u64,
    ) -> Result<Self, PcError> {
        if actor_cache_a.batch_size() != actor_cache_b.batch_size() {
            return Err(PcError::DimensionMismatch {
                expected: actor_cache_a.batch_size(),
                got: actor_cache_b.batch_size(),
                context: "actor activation cache batch sizes must match for crossover",
            });
        }
        if critic_cache_a.batch_size() != critic_cache_b.batch_size() {
            return Err(PcError::DimensionMismatch {
                expected: critic_cache_a.batch_size(),
                got: critic_cache_b.batch_size(),
                context: "critic activation cache batch sizes must match for crossover",
            });
        }

        let actor_cache_mats_a = cache_to_matrices::<L>(actor_cache_a);
        let actor_cache_mats_b = cache_to_matrices::<L>(actor_cache_b);
        let critic_cache_mats_a = cache_to_matrices::<L>(critic_cache_a);
        let critic_cache_mats_b = cache_to_matrices::<L>(critic_cache_b);

        use rand::SeedableRng;
        let mut rng = StdRng::seed_from_u64(seed);

        let actor = PcActor::<L>::crossover(
            &parent_a.actor,
            &parent_b.actor,
            &actor_cache_mats_a,
            &actor_cache_mats_b,
            alpha,
            child_config.actor.clone(),
            &mut rng,
        )?;

        let critic = MlpCritic::<L>::crossover(
            &parent_a.critic,
            &parent_b.critic,
            &critic_cache_mats_a,
            &critic_cache_mats_b,
            alpha,
            child_config.critic.clone(),
            &mut rng,
        )?;

        Ok(Self {
            actor,
            critic,
            config: child_config,
            rng,
            surprise_buffer: VecDeque::new(),
        })
    }

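    /// Assembles an agent from pre-built parts without re-running config
    /// validation; the surprise buffer starts empty.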
    pub fn from_parts(
        config: PcActorCriticConfig,
        actor: PcActor<L>,
        critic: MlpCritic<L>,
        rng: StdRng,
    ) -> Self {
        Self {
            actor,
            critic,
            config,
            rng,
            surprise_buffer: VecDeque::new(),
        }
    }

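    /// Runs actor inference on `input` without selecting an action or
    /// mutating agent state.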
    pub fn infer(&self, input: &[f64]) -> InferResult<L> {
        self.actor.infer(input)
    }

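    /// Runs inference and selects an action from `valid_actions` under the
    /// given `SelectionMode`, advancing the agent's RNG.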
    pub fn act(
        &mut self,
        input: &[f64],
        valid_actions: &[usize],
        mode: SelectionMode,
    ) -> (usize, InferResult<L>) {
        let infer_result = self.actor.infer(input);
        let action =
            self.actor
                .select_action(&infer_result.y_conv, valid_actions, mode, &mut self.rng);
        (action, infer_result)
    }

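    /// Monte Carlo actor-critic update over a full episode: computes
    /// discounted returns back-to-front, regresses the critic toward each
    /// return, and applies an advantage-weighted, entropy-regularized
    /// policy-gradient update to the actor, scaled by `surprise_scale`.
    /// Returns the mean critic loss; an empty trajectory is a no-op
    /// returning 0.0.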
    pub fn learn(&mut self, trajectory: &[TrajectoryStep<L>]) -> f64 {
        if trajectory.is_empty() {
            return 0.0;
        }

        let n = trajectory.len();

        // Discounted returns, computed back-to-front: G_t = r_t + gamma * G_{t+1}.
        let mut returns = vec![0.0; n];
        returns[n - 1] = trajectory[n - 1].reward;
        for t in (0..n - 1).rev() {
            returns[t] = trajectory[t].reward + self.config.gamma * returns[t + 1];
        }

        let mut total_loss = 0.0;

        for (t, step) in trajectory.iter().enumerate() {
            // The critic scores the concatenation of the raw input and the
            // actor's latent state.
            let input_vec = L::vec_to_vec(&step.input);
            let latent_vec = L::vec_to_vec(&step.latent_concat);
            let mut critic_input = input_vec.clone();
            critic_input.extend_from_slice(&latent_vec);

            let value = self.critic.forward(&critic_input);
            let advantage = returns[t] - value;

            // Regress the critic toward the observed return.
            let loss = self.critic.update(&critic_input, returns[t]);
            total_loss += loss;

            // Recover the policy over valid actions from the stored logits.
            let y_conv_vec = L::vec_to_vec(&step.y_conv);
            let scaled: Vec<f64> = y_conv_vec
                .iter()
                .map(|&v| v / self.actor.config.temperature)
                .collect();
            let scaled_l = L::vec_from_slice(&scaled);
            let pi_l = L::softmax_masked(&scaled_l, &step.valid_actions);
            let pi = L::vec_to_vec(&pi_l);

            // Softmax policy gradient: delta = pi - one_hot(action).
            let mut delta = vec![0.0; pi.len()];
            for &i in &step.valid_actions {
                delta[i] = pi[i];
            }
            delta[step.action] -= 1.0;

            // Weight by the advantage.
            for &i in &step.valid_actions {
                delta[i] *= advantage;
            }

            // Entropy regularization: subtract coeff * (ln pi + 1) to
            // discourage a collapsed, low-entropy policy.
            for &i in &step.valid_actions {
                let log_pi = (pi[i].max(1e-10)).ln();
                delta[i] -= self.config.entropy_coeff * (log_pi + 1.0);
            }

            let s_scale = self.surprise_scale(step.surprise_score);

            // Rebuild an InferResult from the stored step so the actor can
            // replay its weight update against the original activations.
            let stored_infer = InferResult {
                y_conv: step.y_conv.clone(),
                latent_concat: step.latent_concat.clone(),
                hidden_states: step.hidden_states.clone(),
                prediction_errors: step.prediction_errors.clone(),
                surprise_score: step.surprise_score,
                steps_used: step.steps_used,
                converged: false,
                tanh_components: step.tanh_components.clone(),
            };
            self.actor
                .update_weights(&delta, &stored_infer, &input_vec, s_scale);

            if self.config.adaptive_surprise {
                self.push_surprise(step.surprise_score);
            }
        }

        total_loss / n as f64
    }

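    /// Single-transition TD(0) variant of `learn` for continuing tasks: the
    /// critic is regressed toward `reward + gamma * V(s')` (or `reward`
    /// alone when `terminal`), and the TD error replaces the Monte Carlo
    /// advantage in the actor update. Returns the critic loss for this
    /// transition.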
    #[allow(clippy::too_many_arguments)]
    pub fn learn_continuous(
        &mut self,
        input: &[f64],
        infer: &InferResult<L>,
        action: usize,
        valid_actions: &[usize],
        reward: f64,
        next_input: &[f64],
        next_infer: &InferResult<L>,
        terminal: bool,
    ) -> f64 {
        // Critic inputs for the current and next state: [input ++ latent].
        let latent_vec = L::vec_to_vec(&infer.latent_concat);
        let mut critic_input = input.to_vec();
        critic_input.extend_from_slice(&latent_vec);

        let next_latent_vec = L::vec_to_vec(&next_infer.latent_concat);
        let mut next_critic_input = next_input.to_vec();
        next_critic_input.extend_from_slice(&next_latent_vec);

        let v_s = self.critic.forward(&critic_input);
        let v_next = if terminal {
            0.0
        } else {
            self.critic.forward(&next_critic_input)
        };

        // TD(0) target and error; terminal states use the reward alone.
        let target = reward
            + if terminal {
                0.0
            } else {
                self.config.gamma * v_next
            };
        let td_error = target - v_s;

        let loss = self.critic.update(&critic_input, target);

        // Policy over valid actions from the stored logits.
        let y_conv_vec = L::vec_to_vec(&infer.y_conv);
        let scaled: Vec<f64> = y_conv_vec
            .iter()
            .map(|&v| v / self.actor.config.temperature)
            .collect();
        let scaled_l = L::vec_from_slice(&scaled);
        let pi_l = L::softmax_masked(&scaled_l, valid_actions);
        let pi = L::vec_to_vec(&pi_l);

        // Softmax policy gradient weighted by the TD error, with the same
        // entropy regularization as `learn`.
        let mut delta = vec![0.0; pi.len()];
        for &i in valid_actions {
            delta[i] = pi[i];
        }
        delta[action] -= 1.0;

        for &i in valid_actions {
            delta[i] *= td_error;
        }

        for &i in valid_actions {
            let log_pi = (pi[i].max(1e-10)).ln();
            delta[i] -= self.config.entropy_coeff * (log_pi + 1.0);
        }

        let s_scale = self.surprise_scale(infer.surprise_score);
        self.actor.update_weights(&delta, infer, input, s_scale);

        if self.config.adaptive_surprise {
            self.push_surprise(infer.surprise_score);
        }

        loss
    }

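    /// Maps a surprise score to a learning-rate multiplier in `[0.1, 2.0]`:
    /// 0.1 at or below `low`, 2.0 at or above `high`, and linear in between.
    /// With adaptive surprise and at least 10 buffered scores, `low`/`high`
    /// are derived from the buffer as `mean - 0.5 * std` (clamped at 0.0)
    /// and `mean + 1.5 * std`; otherwise the configured thresholds apply.
    /// For example, with the defaults `low = 0.02` and `high = 0.15`, a
    /// surprise of 0.085 sits at the midpoint and maps to
    /// `0.1 + 0.5 * 1.9 = 1.05`.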
    pub fn surprise_scale(&self, surprise: f64) -> f64 {
        // With enough buffered history, derive thresholds from the running
        // mean and standard deviation; otherwise use the configured ones.
        let (low, high) = if self.config.adaptive_surprise && self.surprise_buffer.len() >= 10 {
            let mean = self.surprise_buffer.iter().sum::<f64>() / self.surprise_buffer.len() as f64;
            let variance = self
                .surprise_buffer
                .iter()
                .map(|&s| (s - mean) * (s - mean))
                .sum::<f64>()
                / self.surprise_buffer.len() as f64;
            let std = variance.sqrt();
            let lo = (mean - 0.5 * std).max(0.0);
            let hi = mean + 1.5 * std;
            (lo, hi)
        } else {
            (self.config.surprise_low, self.config.surprise_high)
        };

        if surprise <= low {
            0.1
        } else if surprise >= high {
            2.0
        } else {
            // Linear interpolation between the 0.1 and 2.0 endpoints.
            let t = (surprise - low) / (high - low);
            0.1 + t * (2.0 - 0.1)
        }
    }

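    /// Appends a surprise score to the sliding window, evicting the oldest
    /// entry once the buffer reaches `surprise_buffer_size`.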
    fn push_surprise(&mut self, surprise: f64) {
        if self.surprise_buffer.len() >= self.config.surprise_buffer_size {
            self.surprise_buffer.pop_front();
        }
        self.surprise_buffer.push_back(surprise);
    }
}

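/// Flattens each layer of an [`ActivationCache`] into a
/// `batch_size x n_neurons` matrix (one row per recorded sample); layers
/// with no samples become 0x0 matrices.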
fn cache_to_matrices<L: LinAlg>(cache: &ActivationCache<L>) -> Vec<L::Matrix> {
    let num_layers = cache.num_layers();
    let batch_size = cache.batch_size();
    let mut matrices = Vec::with_capacity(num_layers);

    for layer_idx in 0..num_layers {
        let samples = cache.layer(layer_idx);
        if samples.is_empty() {
            matrices.push(L::zeros_mat(0, 0));
            continue;
        }
        let n_neurons = L::vec_len(&samples[0]);
        let mut mat = L::zeros_mat(batch_size, n_neurons);
        for (r, sample) in samples.iter().enumerate() {
            for c in 0..n_neurons {
                L::mat_set(&mut mat, r, c, L::vec_get(sample, c));
            }
        }
        matrices.push(mat);
    }

    matrices
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::Activation;
    use crate::layer::LayerDef;
    use crate::pc_actor::SelectionMode;

    fn default_config() -> PcActorCriticConfig {
        PcActorCriticConfig {
            actor: PcActorConfig {
                input_size: 9,
                hidden_layers: vec![LayerDef {
                    size: 18,
                    activation: Activation::Tanh,
                }],
                output_size: 9,
                output_activation: Activation::Tanh,
                alpha: 0.1,
                tol: 0.01,
                min_steps: 1,
                max_steps: 20,
                lr_weights: 0.01,
                synchronous: true,
                temperature: 1.0,
                local_lambda: 1.0,
                residual: false,
                rezero_init: 0.001,
            },
            critic: MlpCriticConfig {
                input_size: 27,
                hidden_layers: vec![LayerDef {
                    size: 36,
                    activation: Activation::Tanh,
                }],
                output_activation: Activation::Linear,
                lr: 0.005,
            },
            gamma: 0.95,
            surprise_low: 0.02,
            surprise_high: 0.15,
            adaptive_surprise: false,
            surprise_buffer_size: 100,
            entropy_coeff: 0.01,
        }
    }

    fn make_agent() -> PcActorCritic {
        PcActorCritic::new(default_config(), 42).unwrap()
    }

    fn make_trajectory(agent: &mut PcActorCritic) -> Vec<TrajectoryStep> {
        let input = vec![1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0, 0.0, 0.5];
        let valid = vec![2, 7];
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        vec![TrajectoryStep {
            input,
            latent_concat: infer.latent_concat,
            y_conv: infer.y_conv,
            hidden_states: infer.hidden_states,
            prediction_errors: infer.prediction_errors,
            tanh_components: infer.tanh_components,
            action,
            valid_actions: valid,
            reward: 1.0,
            surprise_score: infer.surprise_score,
            steps_used: infer.steps_used,
        }]
    }

    // --- learn ---

    #[test]
    fn test_learn_empty_returns_zero_without_modifying_weights() {
        let mut agent: PcActorCritic = make_agent();
        let w_before = agent.actor.layers[0].weights.data.clone();
        let cw_before = agent.critic.layers[0].weights.data.clone();
        let loss = agent.learn(&[]);
        assert_eq!(loss, 0.0);
        assert_eq!(agent.actor.layers[0].weights.data, w_before);
        assert_eq!(agent.critic.layers[0].weights.data, cw_before);
    }

    #[test]
    fn test_learn_updates_actor_weights() {
        let mut agent: PcActorCritic = make_agent();
        let trajectory = make_trajectory(&mut agent);
        let w_before = agent.actor.layers[0].weights.data.clone();
        let _ = agent.learn(&trajectory);
        assert_ne!(agent.actor.layers[0].weights.data, w_before);
    }

    #[test]
    fn test_learn_updates_critic_weights() {
        let mut agent: PcActorCritic = make_agent();
        let trajectory = make_trajectory(&mut agent);
        let w_before = agent.critic.layers[0].weights.data.clone();
        let _ = agent.learn(&trajectory);
        assert_ne!(agent.critic.layers[0].weights.data, w_before);
    }

    #[test]
    fn test_learn_returns_finite_nonneg_loss() {
        let mut agent: PcActorCritic = make_agent();
        let trajectory = make_trajectory(&mut agent);
        let loss = agent.learn(&trajectory);
        assert!(loss.is_finite(), "Loss {loss} is not finite");
        assert!(loss >= 0.0, "Loss {loss} is negative");
    }

    #[test]
    fn test_learn_single_step_trajectory() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        let trajectory = vec![TrajectoryStep {
            input,
            latent_concat: infer.latent_concat,
            y_conv: infer.y_conv,
            hidden_states: infer.hidden_states,
            prediction_errors: infer.prediction_errors,
            tanh_components: infer.tanh_components,
            action,
            valid_actions: valid,
            reward: -1.0,
            surprise_score: infer.surprise_score,
            steps_used: infer.steps_used,
        }];
        let loss = agent.learn(&trajectory);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_learn_multi_step_uses_stored_hidden_states() {
        let mut agent: PcActorCritic = make_agent();
        // Three distinct observations; reward arrives only on the last step.
        let inputs = [
            vec![1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0, 0.0, 0.5],
            vec![0.5, 0.5, -1.0, 0.0, 1.0, -0.5, 0.0, -1.0, 0.5],
            vec![-1.0, 0.0, 1.0, -0.5, 0.5, 0.0, 1.0, -1.0, -0.5],
        ];
        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];

        let mut trajectory = Vec::new();
        for (i, inp) in inputs.iter().enumerate() {
            let (action, infer) = agent.act(inp, &valid, SelectionMode::Training);
            trajectory.push(TrajectoryStep {
                input: inp.clone(),
                latent_concat: infer.latent_concat,
                y_conv: infer.y_conv,
                hidden_states: infer.hidden_states,
                prediction_errors: infer.prediction_errors,
                tanh_components: infer.tanh_components,
                action,
                valid_actions: valid.clone(),
                reward: if i == 2 { 1.0 } else { 0.0 },
                surprise_score: infer.surprise_score,
                steps_used: infer.steps_used,
            });
        }

        let loss = agent.learn(&trajectory);
        assert!(
            loss.is_finite(),
            "Multi-step learn should produce finite loss"
        );
        assert!(loss >= 0.0);
    }

    // --- learn_continuous ---

    #[test]
    fn test_learn_continuous_nonterminal_uses_next_value() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let next_input = vec![-0.5; 9];
        let valid = vec![0, 1, 2];
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);

        let loss = agent.learn_continuous(
            &input,
            &infer,
            action,
            &valid,
            0.5,
            &next_input,
            &next_infer,
            false,
        );
        assert!(loss.is_finite());
    }

    #[test]
    fn test_learn_continuous_terminal_uses_reward_only() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let next_input = vec![0.0; 9];
        let valid = vec![0, 1, 2];
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);

        let loss = agent.learn_continuous(
            &input,
            &infer,
            action,
            &valid,
            1.0,
            &next_input,
            &next_infer,
            true,
        );
        assert!(loss.is_finite());
    }

    #[test]
    fn test_learn_continuous_terminal_and_nonterminal_produce_different_updates() {
        // Two identically seeded agents see the same transition; only the
        // terminal flag differs.
        let config = default_config();
        let mut agent_term: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_nonterm: PcActorCritic = PcActorCritic::new(config, 42).unwrap();

        let input = vec![0.5; 9];
        let next_input = vec![-0.5; 9];
        let valid = vec![0, 1, 2];

        let (action, infer) = agent_term.act(&input, &valid, SelectionMode::Training);
        let (_, next_infer) = agent_term.act(&next_input, &valid, SelectionMode::Training);

        let (action2, infer2) = agent_nonterm.act(&input, &valid, SelectionMode::Training);
        let (_, next_infer2) = agent_nonterm.act(&next_input, &valid, SelectionMode::Training);

        let loss_term = agent_term.learn_continuous(
            &input,
            &infer,
            action,
            &valid,
            1.0,
            &next_input,
            &next_infer,
            true,
        );

        let loss_nonterm = agent_nonterm.learn_continuous(
            &input,
            &infer2,
            action2,
            &valid,
            1.0,
            &next_input,
            &next_infer2,
            false,
        );

        assert!(
            (loss_term - loss_nonterm).abs() > 1e-15,
            "Terminal and non-terminal should produce different losses: {loss_term} vs {loss_nonterm}"
        );
    }

    #[test]
    fn test_learn_continuous_updates_actor() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let next_input = vec![-0.5; 9];
        let valid = vec![0, 1, 2];
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
        let w_before = agent.actor.layers[0].weights.data.clone();
        let _ = agent.learn_continuous(
            &input,
            &infer,
            action,
            &valid,
            1.0,
            &next_input,
            &next_infer,
            false,
        );
        assert_ne!(agent.actor.layers[0].weights.data, w_before);
    }

    // --- surprise_scale ---

    #[test]
    fn test_surprise_scale_below_low() {
        let agent: PcActorCritic = make_agent();
        let scale = agent.surprise_scale(0.01);
        assert!((scale - 0.1).abs() < 1e-12, "Expected 0.1, got {scale}");
    }

    #[test]
    fn test_surprise_scale_above_high() {
        let agent: PcActorCritic = make_agent();
        let scale = agent.surprise_scale(0.20);
        assert!((scale - 2.0).abs() < 1e-12, "Expected 2.0, got {scale}");
    }

    #[test]
    fn test_surprise_scale_midpoint_in_range() {
        let agent: PcActorCritic = make_agent();
        let midpoint = (0.02 + 0.15) / 2.0;
        let scale = agent.surprise_scale(midpoint);
        assert!(
            scale > 0.1 && scale < 2.0,
            "Midpoint scale {scale} out of range"
        );
    }

    #[test]
    fn test_surprise_scale_monotone_increasing() {
        let agent: PcActorCritic = make_agent();
        let s1 = agent.surprise_scale(0.01);
        let s2 = agent.surprise_scale(0.05);
        let s3 = agent.surprise_scale(0.10);
        let s4 = agent.surprise_scale(0.20);
        assert!(s1 <= s2, "s1={s1} > s2={s2}");
        assert!(s2 <= s3, "s2={s2} > s3={s3}");
        assert!(s3 <= s4, "s3={s3} > s4={s4}");
    }

    #[test]
    fn test_adaptive_surprise_recalibrates_thresholds_after_many_episodes() {
        let mut config = default_config();
        config.adaptive_surprise = true;
        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();

        // Seed the buffer with 15 scores from 0.10 to 0.38 (mean 0.24), well
        // above the configured thresholds.
        for i in 0..15 {
            agent.push_surprise(0.1 + 0.02 * i as f64);
        }

        // 0.0 lies below the adaptive low (mean - 0.5 * std).
        let scale_low = agent.surprise_scale(0.0);
        assert!(
            (scale_low - 0.1).abs() < 1e-12,
            "Expected 0.1 below adaptive low: got {scale_low}"
        );

        // 1.0 lies above the adaptive high (mean + 1.5 * std).
        let scale_high = agent.surprise_scale(1.0);
        assert!(
            (scale_high - 2.0).abs() < 1e-12,
            "Expected 2.0 above adaptive high: got {scale_high}"
        );

        // The buffer mean should interpolate strictly between the endpoints.
        let scale_mid = agent.surprise_scale(0.24);
        assert!(
            scale_mid > 0.1 && scale_mid < 2.0,
            "Expected interpolated value at mean, got {scale_mid}"
        );
    }

    #[test]
    fn test_entropy_regularization_prevents_policy_collapse() {
        let mut config = default_config();
        // A comparatively large entropy coefficient keeps the policy spread out.
        config.entropy_coeff = 0.1;
        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();

        let input = vec![0.5; 9];
        let valid: Vec<usize> = (0..9).collect();

        for _ in 0..20 {
            let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
            let trajectory = vec![TrajectoryStep {
                input: input.clone(),
                latent_concat: infer.latent_concat,
                y_conv: infer.y_conv,
                hidden_states: infer.hidden_states,
                prediction_errors: infer.prediction_errors,
                tanh_components: infer.tanh_components,
                action,
                valid_actions: valid.clone(),
                reward: 1.0,
                surprise_score: infer.surprise_score,
                steps_used: infer.steps_used,
            }];
            let _ = agent.learn(&trajectory);
        }

        // A non-collapsed policy should still visit more than one action.
        let mut seen = std::collections::HashSet::new();
        for _ in 0..50 {
            let (action, _) = agent.act(&input, &valid, SelectionMode::Training);
            seen.insert(action);
        }
        assert!(
            seen.len() > 1,
            "Entropy regularization should prevent collapse to single action, but only saw {:?}",
            seen
        );
    }

    // --- act ---

    #[test]
    fn test_act_returns_valid_action() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let valid = vec![1, 3, 5, 7];
        for _ in 0..20 {
            let (action, _) = agent.act(&input, &valid, SelectionMode::Training);
            assert!(valid.contains(&action), "Action {action} not in valid set");
        }
    }

    #[test]
    #[should_panic]
    fn test_act_empty_valid_panics() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let _ = agent.act(&input, &[], SelectionMode::Training);
    }

    // --- policy improvement ---

    #[test]
    fn test_learn_improves_policy_for_rewarded_action() {
        let config = PcActorCriticConfig {
            actor: PcActorConfig {
                input_size: 9,
                hidden_layers: vec![LayerDef {
                    size: 18,
                    activation: Activation::Tanh,
                }],
                output_size: 9,
                output_activation: Activation::Linear,
                alpha: 0.1,
                tol: 0.01,
                min_steps: 1,
                max_steps: 5,
                lr_weights: 0.01,
                synchronous: true,
                temperature: 1.0,
                local_lambda: 1.0,
                residual: false,
                rezero_init: 0.001,
            },
            critic: MlpCriticConfig {
                input_size: 27,
                hidden_layers: vec![LayerDef {
                    size: 36,
                    activation: Activation::Tanh,
                }],
                output_activation: Activation::Linear,
                lr: 0.005,
            },
            gamma: 0.99,
            surprise_low: 0.02,
            surprise_high: 0.15,
            adaptive_surprise: false,
            surprise_buffer_size: 100,
            entropy_coeff: 0.0, // No entropy bonus, so the policy can commit.
        };
        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();

        let input = vec![0.0; 9];
        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
        let target_action = 4;

        // Reward action 4 on every episode, regardless of what was sampled.
        for _ in 0..200 {
            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
            let trajectory = vec![TrajectoryStep {
                input: input.clone(),
                latent_concat: infer.latent_concat,
                y_conv: infer.y_conv,
                hidden_states: infer.hidden_states,
                prediction_errors: infer.prediction_errors,
                tanh_components: infer.tanh_components,
                action: target_action,
                valid_actions: valid.clone(),
                reward: 1.0,
                surprise_score: infer.surprise_score,
                steps_used: infer.steps_used,
            }];
            agent.learn(&trajectory);
        }

        // Greedy selection after training should prefer the rewarded action.
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Play);

        let logit_4 = infer.y_conv[4];
        let max_other = valid
            .iter()
            .filter(|&&a| a != 4)
            .map(|&a| infer.y_conv[a])
            .fold(f64::NEG_INFINITY, f64::max);

        eprintln!(
            "DIAGNOSTIC: action={action}, logit[4]={logit_4:.4}, max_other={max_other:.4}, \
             y_conv={:?}",
            infer
                .y_conv
                .iter()
                .map(|v| format!("{v:.3}"))
                .collect::<Vec<_>>()
        );

        assert_eq!(
            action, target_action,
            "After 200 episodes rewarding action 4, agent should prefer it. Got action {action}"
        );
    }

    // --- config validation ---

    #[test]
    fn test_new_returns_error_zero_temperature() {
        let mut config = default_config();
        config.actor.temperature = 0.0;
        let err = PcActorCritic::new(config, 42)
            .map(|_: PcActorCritic| ())
            .unwrap_err();
        assert!(format!("{err}").contains("temperature"));
    }

    #[test]
    fn test_new_returns_error_zero_input_size() {
        let mut config = default_config();
        config.actor.input_size = 0;
        config.critic.input_size = 0;
        assert!(PcActorCritic::new(config, 42)
            .map(|_: PcActorCritic| ())
            .is_err());
    }

    #[test]
    fn test_new_returns_error_zero_output_size() {
        let mut config = default_config();
        config.actor.output_size = 0;
        assert!(PcActorCritic::new(config, 42)
            .map(|_: PcActorCritic| ())
            .is_err());
    }

    #[test]
    fn test_new_returns_error_negative_gamma() {
        let mut config = default_config();
        config.gamma = -0.1;
        let err = PcActorCritic::new(config, 42)
            .map(|_: PcActorCritic| ())
            .unwrap_err();
        assert!(format!("{err}").contains("gamma"));
    }

    #[test]
    fn test_new_returns_error_surprise_buffer_size_zero() {
        let mut config = default_config();
        config.adaptive_surprise = true;
        config.surprise_buffer_size = 0;
        let result = PcActorCritic::new(config, 42).map(|_: PcActorCritic| ());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            format!("{err}").contains("surprise_buffer_size"),
            "Expected surprise_buffer_size error, got: {err}"
        );
    }

    // --- ActivationCache ---

    #[test]
    fn test_activation_cache_new_creates_empty() {
        let cache: ActivationCache = ActivationCache::new(3);
        assert_eq!(cache.batch_size(), 0);
    }

    #[test]
    fn test_activation_cache_record_increments_batch_size() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let valid = vec![0, 1, 2];
        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);

        let num_hidden = infer.hidden_states.len();
        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
        cache.record(&infer.hidden_states);
        assert_eq!(cache.batch_size(), 1);
    }

    #[test]
    fn test_activation_cache_record_multiple() {
        let mut agent: PcActorCritic = make_agent();
        let valid = vec![0, 1, 2];
        let init_input = vec![0.5; 9];
        let num_hidden = {
            let (_, infer) = agent.act(&init_input, &valid, SelectionMode::Training);
            infer.hidden_states.len()
        };

        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
        for i in 0..5 {
            let input = vec![i as f64 * 0.1; 9];
            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
            cache.record(&infer.hidden_states);
        }
        assert_eq!(cache.batch_size(), 5);
    }

    #[test]
    fn test_activation_cache_recorded_values_match_hidden_states() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let valid = vec![0, 1, 2];
        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);

        let num_hidden = infer.hidden_states.len();
        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
        cache.record(&infer.hidden_states);

        // Each layer should hold exactly the one recorded sample, verbatim.
        for (layer_idx, expected) in infer.hidden_states.iter().enumerate() {
            let layer_data = cache.layer(layer_idx);
            assert_eq!(layer_data.len(), 1);
            assert_eq!(layer_data[0], *expected);
        }
    }

    #[test]
    fn test_activation_cache_layer_count() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let valid = vec![0, 1, 2];
        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);

        let num_hidden = infer.hidden_states.len();
        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
        cache.record(&infer.hidden_states);

        assert_eq!(cache.num_layers(), num_hidden);
    }

    #[test]
    fn test_activation_cache_layer_sample_count() {
        let mut agent: PcActorCritic = make_agent();
        let valid = vec![0, 1, 2];
        let init_input = vec![0.5; 9];
        let num_hidden = {
            let (_, infer) = agent.act(&init_input, &valid, SelectionMode::Training);
            infer.hidden_states.len()
        };

        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
        for i in 0..10 {
            let input = vec![i as f64 * 0.1; 9];
            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
            cache.record(&infer.hidden_states);
        }

        for layer_idx in 0..num_hidden {
            assert_eq!(
                cache.layer(layer_idx).len(),
                10,
                "Layer {layer_idx} should have 10 samples"
            );
        }
    }

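    // Rolls `batch_size` deterministic pseudo-inputs through the agent,
    // recording the actor's hidden states and the critic's hidden states
    // for each sample.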
    fn build_caches_for_agent(
        agent: &mut PcActorCritic,
        batch_size: usize,
    ) -> (ActivationCache, ActivationCache) {
        let num_actor_hidden = agent.config.actor.hidden_layers.len();
        let num_critic_hidden = agent.config.critic.hidden_layers.len();
        let mut actor_cache: ActivationCache = ActivationCache::new(num_actor_hidden);
        let mut critic_cache: ActivationCache = ActivationCache::new(num_critic_hidden);
        let valid: Vec<usize> = (0..agent.config.actor.output_size).collect();
        for i in 0..batch_size {
            let input: Vec<f64> = (0..agent.config.actor.input_size)
                .map(|j| ((i * 9 + j) as f64 * 0.1).sin())
                .collect();
            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
            actor_cache.record(&infer.hidden_states);
            let mut critic_input = input;
            critic_input.extend_from_slice(&infer.latent_concat);
            let (_value, critic_hidden) = agent.critic.forward_with_hidden(&critic_input);
            critic_cache.record(&critic_hidden);
        }
        (actor_cache, critic_cache)
    }

    #[test]
    fn test_agent_crossover_produces_valid_agent() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);

        let child: PcActorCritic = PcActorCritic::crossover(
            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
        )
        .unwrap();

        assert_eq!(
            child.config.actor.hidden_layers.len(),
            agent_a.config.actor.hidden_layers.len()
        );
    }

    #[test]
    fn test_agent_crossover_actor_weights_differ() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);

        let child: PcActorCritic = PcActorCritic::crossover(
            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
        )
        .unwrap();

        assert_ne!(
            child.actor.layers[0].weights.data,
            agent_a.actor.layers[0].weights.data
        );
        assert_ne!(
            child.actor.layers[0].weights.data,
            agent_b.actor.layers[0].weights.data
        );
    }

    #[test]
    fn test_agent_crossover_critic_weights_differ() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);

        let child: PcActorCritic = PcActorCritic::crossover(
            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
        )
        .unwrap();

        assert_ne!(
            child.critic.layers[0].weights.data,
            agent_a.critic.layers[0].weights.data
        );
        assert_ne!(
            child.critic.layers[0].weights.data,
            agent_b.critic.layers[0].weights.data
        );
    }

    #[test]
    fn test_agent_crossover_child_can_infer() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);

        let mut child: PcActorCritic = PcActorCritic::crossover(
            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
        )
        .unwrap();

        let input = vec![0.5; 9];
        let valid = vec![0, 1, 2, 3, 4];
        let (action, _) = child.act(&input, &valid, SelectionMode::Training);
        assert!(valid.contains(&action), "Action {action} not in valid set");
    }

    #[test]
    fn test_agent_crossover_child_can_learn() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);

        let mut child: PcActorCritic = PcActorCritic::crossover(
            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
        )
        .unwrap();

        let trajectory = make_trajectory(&mut child);
        let loss = child.learn(&trajectory);
        assert!(loss.is_finite(), "Child learn loss not finite: {loss}");
    }

    #[test]
    fn test_agent_crossover_mismatched_batch_size_error() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        // Actor caches mismatch (30 vs 50); critic caches match at 50.
        let (ac_b, _cc_b) = build_caches_for_agent(&mut agent_b, 30);
        let (_, cc_b_match) = build_caches_for_agent(&mut agent_b, 50);

        let result = PcActorCritic::crossover(
            &agent_a,
            &agent_b,
            &ac_a,
            &ac_b,
            &cc_a,
            &cc_b_match,
            0.5,
            config,
            99,
        );
        assert!(result.is_err(), "Mismatched actor batch sizes should error");
    }

    #[test]
    fn test_agent_crossover_with_separate_critic_caches() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);

        let child: PcActorCritic = PcActorCritic::crossover(
            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
        )
        .unwrap();

        assert_eq!(child.critic.layers.len(), agent_a.critic.layers.len());
    }

    #[test]
    fn test_agent_crossover_critic_uses_own_caches() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);

        let child: PcActorCritic = PcActorCritic::crossover(
            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
        )
        .unwrap();

        assert_ne!(
            child.critic.layers[0].weights.data,
            agent_a.critic.layers[0].weights.data
        );
        assert_ne!(
            child.critic.layers[0].weights.data,
            agent_b.critic.layers[0].weights.data
        );
    }

    #[test]
    fn test_agent_crossover_mismatched_critic_batch_error() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        let (ac_b, _) = build_caches_for_agent(&mut agent_b, 50);
        let (_, cc_b_small) = build_caches_for_agent(&mut agent_b, 30);

        let result = PcActorCritic::crossover(
            &agent_a,
            &agent_b,
            &ac_a,
            &ac_b,
            &cc_a,
            &cc_b_small,
            0.5,
            config,
            99,
        );
        assert!(
            result.is_err(),
            "Mismatched critic batch sizes should error"
        );
    }

    // --- crate-level visibility ---

    #[test]
    fn test_activation_cache_accessible_from_crate() {
        let _cache: crate::pc_actor_critic::ActivationCache = ActivationCache::new(1);
    }

    #[test]
    fn test_cca_neuron_alignment_accessible_from_crate() {
        use crate::linalg::cpu::CpuLinAlg;
        use crate::linalg::LinAlg;
        let mat = CpuLinAlg::zeros_mat(10, 3);
        let _perm = crate::matrix::cca_neuron_alignment::<CpuLinAlg>(&mat, &mat).unwrap();
    }
}