1use std::collections::VecDeque;
15
16use rand::rngs::StdRng;
17use serde::{Deserialize, Serialize};
18
19use crate::error::PcError;
20use crate::linalg::cpu::CpuLinAlg;
21use crate::linalg::LinAlg;
22use crate::mlp_critic::{MlpCritic, MlpCriticConfig};
23use crate::pc_actor::{InferResult, PcActor, PcActorConfig, SelectionMode};
24
/// Configuration for a `PcActorCritic` agent: the two sub-network configs
/// plus the RL hyper-parameters consumed by `learn`/`learn_continuous`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PcActorCriticConfig {
    /// Configuration of the predictive-coding actor (policy) network.
    pub actor: PcActorConfig,
    /// Configuration of the MLP value critic.
    pub critic: MlpCriticConfig,
    /// Discount factor for returns / TD targets; validated at construction
    /// to lie in `[0.0, 1.0]`.
    pub gamma: f64,
    /// Lower surprise threshold: at or below it the weight-update scale is 0.1.
    pub surprise_low: f64,
    /// Upper surprise threshold: at or above it the weight-update scale is 2.0.
    pub surprise_high: f64,
    /// When true, thresholds are recalibrated from a rolling buffer of recent
    /// surprise scores (mean +/- multiples of std) once 10+ samples exist.
    pub adaptive_surprise: bool,
    /// Coefficient of the entropy-bonus term mixed into the policy gradient.
    pub entropy_coeff: f64,
}
78
/// One recorded environment step: everything `learn` needs to replay the
/// actor's inference and apply policy/critic updates offline.
#[derive(Debug, Clone)]
pub struct TrajectoryStep<L: LinAlg = CpuLinAlg> {
    /// Observation fed to the actor at this step.
    pub input: L::Vector,
    /// Concatenated latent state from inference; appended to `input` to form
    /// the critic's input during learning.
    pub latent_concat: L::Vector,
    /// Converged output logits; re-softmaxed (masked, temperature-scaled)
    /// during learning.
    pub y_conv: L::Vector,
    /// Hidden-layer states captured at inference time.
    pub hidden_states: Vec<L::Vector>,
    /// Per-layer prediction errors from the predictive-coding loop.
    pub prediction_errors: Vec<L::Vector>,
    /// Optional per-layer tanh components (None where not applicable).
    pub tanh_components: Vec<Option<L::Vector>>,
    /// Action actually taken at this step.
    pub action: usize,
    /// Actions that were legal here; softmax and gradients are masked to these.
    pub valid_actions: Vec<usize>,
    /// Immediate reward received after taking `action`.
    pub reward: f64,
    /// Surprise score reported by inference; drives the update scale.
    pub surprise_score: f64,
    /// Number of inference iterations the actor used.
    pub steps_used: usize,
}
107
/// Per-layer record of hidden-state vectors collected over a batch of
/// inferences; converted to matrices (`cache_to_matrices`) for
/// activation-based crossover.
#[derive(Debug, Clone)]
pub struct ActivationCache<L: LinAlg = CpuLinAlg> {
    // layers[layer_idx] holds one recorded vector per sample.
    layers: Vec<Vec<L::Vector>>,
}
131
132impl<L: LinAlg> ActivationCache<L> {
133 pub fn new(num_layers: usize) -> Self {
139 Self {
140 layers: (0..num_layers).map(|_| Vec::new()).collect(),
141 }
142 }
143
144 pub fn batch_size(&self) -> usize {
146 self.layers.first().map_or(0, |l| l.len())
147 }
148
149 pub fn num_layers(&self) -> usize {
151 self.layers.len()
152 }
153
154 pub fn record(&mut self, hidden_states: &[L::Vector]) {
160 for (layer, state) in self.layers.iter_mut().zip(hidden_states.iter()) {
161 layer.push(state.clone());
162 }
163 }
164
165 pub fn layer(&self, layer_idx: usize) -> &[L::Vector] {
171 &self.layers[layer_idx]
172 }
173}
174
/// Predictive-coding actor-critic agent: a PC actor for the policy, an MLP
/// critic for state values, an RNG for stochastic action selection, and a
/// rolling buffer of surprise scores for adaptive update scaling.
#[derive(Debug)]
pub struct PcActorCritic<L: LinAlg = CpuLinAlg> {
    /// Policy network (predictive-coding actor).
    pub(crate) actor: PcActor<L>,
    /// Value network.
    pub(crate) critic: MlpCritic<L>,
    /// Hyper-parameters; `gamma` is validated in `new`.
    pub config: PcActorCriticConfig,
    // RNG consumed by `act` when sampling actions.
    rng: StdRng,
    // Rolling window (capped at 100 entries) of recent surprise scores,
    // used for adaptive threshold calibration in `surprise_scale`.
    surprise_buffer: VecDeque<f64>,
}
194
impl<L: LinAlg> PcActorCritic<L> {
    /// Builds a new agent, seeding actor and critic initialisation from one
    /// `seed` so the whole agent is reproducible.
    ///
    /// # Errors
    /// Returns `PcError::ConfigValidation` when `gamma` lies outside
    /// `[0.0, 1.0]`, and propagates actor/critic construction errors.
    pub fn new(config: PcActorCriticConfig, seed: u64) -> Result<Self, PcError> {
        // Discount factor must be a valid weighting in [0, 1].
        if !(0.0..=1.0).contains(&config.gamma) {
            return Err(PcError::ConfigValidation(format!(
                "gamma must be in [0.0, 1.0], got {}",
                config.gamma
            )));
        }

        use rand::SeedableRng;
        // One RNG drives both sub-network initialisations, then is retained
        // for action sampling.
        let mut rng = StdRng::seed_from_u64(seed);
        let actor = PcActor::<L>::new(config.actor.clone(), &mut rng)?;
        let critic = MlpCritic::<L>::new(config.critic.clone(), &mut rng)?;
        Ok(Self {
            actor,
            critic,
            config,
            rng,
            surprise_buffer: VecDeque::new(),
        })
    }

    /// Produces a child agent by activation-based crossover of two parents.
    ///
    /// The caches hold hidden activations for matching input batches; each
    /// parent pair's caches must have equal batch sizes. `alpha` is the blend
    /// factor forwarded to the sub-network crossover routines. The child gets
    /// a fresh RNG from `seed` and an empty surprise buffer.
    ///
    /// # Errors
    /// `PcError::DimensionMismatch` when either cache pair disagrees on batch
    /// size; otherwise propagates sub-network crossover errors.
    #[allow(clippy::too_many_arguments)]
    pub fn crossover(
        parent_a: &PcActorCritic<L>,
        parent_b: &PcActorCritic<L>,
        actor_cache_a: &ActivationCache<L>,
        actor_cache_b: &ActivationCache<L>,
        critic_cache_a: &ActivationCache<L>,
        critic_cache_b: &ActivationCache<L>,
        alpha: f64,
        child_config: PcActorCriticConfig,
        seed: u64,
    ) -> Result<Self, PcError> {
        if actor_cache_a.batch_size() != actor_cache_b.batch_size() {
            return Err(PcError::DimensionMismatch {
                expected: actor_cache_a.batch_size(),
                got: actor_cache_b.batch_size(),
                context: "actor activation cache batch sizes must match for crossover",
            });
        }
        if critic_cache_a.batch_size() != critic_cache_b.batch_size() {
            return Err(PcError::DimensionMismatch {
                expected: critic_cache_a.batch_size(),
                got: critic_cache_b.batch_size(),
                context: "critic activation cache batch sizes must match for crossover",
            });
        }

        // Flatten the per-layer sample lists into (batch x neurons) matrices
        // expected by the sub-network crossover implementations.
        let actor_cache_mats_a = cache_to_matrices::<L>(actor_cache_a);
        let actor_cache_mats_b = cache_to_matrices::<L>(actor_cache_b);
        let critic_cache_mats_a = cache_to_matrices::<L>(critic_cache_a);
        let critic_cache_mats_b = cache_to_matrices::<L>(critic_cache_b);

        use rand::SeedableRng;
        let mut rng = StdRng::seed_from_u64(seed);

        let actor = PcActor::<L>::crossover(
            &parent_a.actor,
            &parent_b.actor,
            &actor_cache_mats_a,
            &actor_cache_mats_b,
            alpha,
            child_config.actor.clone(),
            &mut rng,
        )?;

        let critic = MlpCritic::<L>::crossover(
            &parent_a.critic,
            &parent_b.critic,
            &critic_cache_mats_a,
            &critic_cache_mats_b,
            alpha,
            child_config.critic.clone(),
            &mut rng,
        )?;

        Ok(Self {
            actor,
            critic,
            config: child_config,
            rng,
            surprise_buffer: VecDeque::new(),
        })
    }

    /// Assembles an agent from pre-built components; no validation is
    /// performed here (unlike `new`). The surprise buffer starts empty.
    pub fn from_parts(
        config: PcActorCriticConfig,
        actor: PcActor<L>,
        critic: MlpCritic<L>,
        rng: StdRng,
    ) -> Self {
        Self {
            actor,
            critic,
            config,
            rng,
            surprise_buffer: VecDeque::new(),
        }
    }

    /// Runs actor inference on `input` without selecting an action.
    pub fn infer(&self, input: &[f64]) -> InferResult<L> {
        self.actor.infer(input)
    }

    /// Runs inference and selects an action from `valid_actions` using the
    /// actor's selection policy for `mode` and this agent's RNG.
    /// Returns the chosen action together with the full inference result.
    pub fn act(
        &mut self,
        input: &[f64],
        valid_actions: &[usize],
        mode: SelectionMode,
    ) -> (usize, InferResult<L>) {
        let infer_result = self.actor.infer(input);
        let action =
            self.actor
                .select_action(&infer_result.y_conv, valid_actions, mode, &mut self.rng);
        (action, infer_result)
    }

    /// Monte-Carlo (episodic) update over a full trajectory.
    ///
    /// Computes discounted returns, trains the critic toward them, and applies
    /// an advantage-weighted policy-gradient update (with entropy bonus and
    /// surprise scaling) to the actor from the stored inference data.
    /// Returns the mean critic loss over the trajectory; 0.0 for an empty one.
    pub fn learn(&mut self, trajectory: &[TrajectoryStep<L>]) -> f64 {
        if trajectory.is_empty() {
            return 0.0;
        }

        let n = trajectory.len();

        // Discounted returns, computed backwards: G_t = r_t + gamma * G_{t+1}.
        let mut returns = vec![0.0; n];
        returns[n - 1] = trajectory[n - 1].reward;
        for t in (0..n - 1).rev() {
            returns[t] = trajectory[t].reward + self.config.gamma * returns[t + 1];
        }

        let mut total_loss = 0.0;

        for (t, step) in trajectory.iter().enumerate() {
            // Critic input is the observation concatenated with the actor's
            // latent state from this step.
            let input_vec = L::vec_to_vec(&step.input);
            let latent_vec = L::vec_to_vec(&step.latent_concat);
            let mut critic_input = input_vec.clone();
            critic_input.extend_from_slice(&latent_vec);

            // Advantage uses the critic's value *before* it is updated below.
            let value = self.critic.forward(&critic_input);
            let advantage = returns[t] - value;

            let loss = self.critic.update(&critic_input, returns[t]);
            total_loss += loss;

            // Recompute the masked, temperature-scaled policy from the stored
            // converged logits.
            let y_conv_vec = L::vec_to_vec(&step.y_conv);
            let scaled: Vec<f64> = y_conv_vec
                .iter()
                .map(|&v| v / self.actor.config.temperature)
                .collect();
            let scaled_l = L::vec_from_slice(&scaled);
            let pi_l = L::softmax_masked(&scaled_l, &step.valid_actions);
            let pi = L::vec_to_vec(&pi_l);

            // Policy-gradient delta on valid logits: (pi - onehot(action)).
            let mut delta = vec![0.0; pi.len()];
            for &i in &step.valid_actions {
                delta[i] = pi[i];
            }
            delta[step.action] -= 1.0;

            // Weight by the advantage (only valid entries are non-zero).
            for &i in &step.valid_actions {
                delta[i] *= advantage;
            }

            // Entropy bonus; log is clamped to avoid -inf at pi ~ 0.
            for &i in &step.valid_actions {
                let log_pi = (pi[i].max(1e-10)).ln();
                delta[i] -= self.config.entropy_coeff * (log_pi + 1.0);
            }

            let s_scale = self.surprise_scale(step.surprise_score);

            // Reconstruct the inference snapshot from the stored step.
            // NOTE(review): TrajectoryStep does not record convergence, so
            // `converged` is hardcoded false here — confirm update_weights
            // does not depend on it.
            let stored_infer = InferResult {
                y_conv: step.y_conv.clone(),
                latent_concat: step.latent_concat.clone(),
                hidden_states: step.hidden_states.clone(),
                prediction_errors: step.prediction_errors.clone(),
                surprise_score: step.surprise_score,
                steps_used: step.steps_used,
                converged: false,
                tanh_components: step.tanh_components.clone(),
            };
            self.actor
                .update_weights(&delta, &stored_infer, &input_vec, s_scale);

            // Feed the adaptive-threshold buffer only when the feature is on.
            if self.config.adaptive_surprise {
                self.push_surprise(step.surprise_score);
            }
        }

        total_loss / n as f64
    }

    /// One-step TD (online) update from a single transition.
    ///
    /// The critic is trained toward `r + gamma * V(s')` (just `r` when
    /// `terminal`), and the actor receives a TD-error-weighted policy-gradient
    /// update with the same entropy bonus and surprise scaling as `learn`.
    /// Returns the critic loss for this transition.
    #[allow(clippy::too_many_arguments)]
    pub fn learn_continuous(
        &mut self,
        input: &[f64],
        infer: &InferResult<L>,
        action: usize,
        valid_actions: &[usize],
        reward: f64,
        next_input: &[f64],
        next_infer: &InferResult<L>,
        terminal: bool,
    ) -> f64 {
        // Critic inputs: observation ++ latent, for s and s'.
        let latent_vec = L::vec_to_vec(&infer.latent_concat);
        let mut critic_input = input.to_vec();
        critic_input.extend_from_slice(&latent_vec);

        let next_latent_vec = L::vec_to_vec(&next_infer.latent_concat);
        let mut next_critic_input = next_input.to_vec();
        next_critic_input.extend_from_slice(&next_latent_vec);

        let v_s = self.critic.forward(&critic_input);
        // V(s') is zeroed on terminal transitions (no bootstrap).
        let v_next = if terminal {
            0.0
        } else {
            self.critic.forward(&next_critic_input)
        };

        // `terminal` is re-checked here; equivalent to reward + gamma * v_next
        // since v_next is already 0.0 on terminal.
        let target = reward
            + if terminal {
                0.0
            } else {
                self.config.gamma * v_next
            };
        let td_error = target - v_s;

        let loss = self.critic.update(&critic_input, target);

        // Masked, temperature-scaled policy from the stored logits.
        let y_conv_vec = L::vec_to_vec(&infer.y_conv);
        let scaled: Vec<f64> = y_conv_vec
            .iter()
            .map(|&v| v / self.actor.config.temperature)
            .collect();
        let scaled_l = L::vec_from_slice(&scaled);
        let pi_l = L::softmax_masked(&scaled_l, valid_actions);
        let pi = L::vec_to_vec(&pi_l);

        // (pi - onehot(action)) on valid entries, weighted by the TD error.
        let mut delta = vec![0.0; pi.len()];
        for &i in valid_actions {
            delta[i] = pi[i];
        }
        delta[action] -= 1.0;

        for &i in valid_actions {
            delta[i] *= td_error;
        }

        // Entropy bonus; log clamped to avoid -inf at pi ~ 0.
        for &i in valid_actions {
            let log_pi = (pi[i].max(1e-10)).ln();
            delta[i] -= self.config.entropy_coeff * (log_pi + 1.0);
        }

        let s_scale = self.surprise_scale(infer.surprise_score);
        self.actor.update_weights(&delta, infer, input, s_scale);

        if self.config.adaptive_surprise {
            self.push_surprise(infer.surprise_score);
        }

        loss
    }

    /// Maps a surprise score to a weight-update scale in [0.1, 2.0].
    ///
    /// Thresholds come from config, unless adaptive mode is on and at least
    /// 10 scores are buffered — then low/high are derived from the buffer's
    /// mean and standard deviation (mean - 0.5*std clamped at 0, mean + 1.5*std).
    /// At or below `low` → 0.1; at or above `high` → 2.0; linear in between.
    /// NOTE(review): the interpolation branch is only reached when
    /// low < surprise < high, so it assumes config has surprise_low < surprise_high
    /// for a meaningful ramp — a degenerate config still cannot divide by zero
    /// because the threshold branches fire first.
    pub fn surprise_scale(&self, surprise: f64) -> f64 {
        let (low, high) = if self.config.adaptive_surprise && self.surprise_buffer.len() >= 10 {
            let mean = self.surprise_buffer.iter().sum::<f64>() / self.surprise_buffer.len() as f64;
            // Population variance over the rolling window.
            let variance = self
                .surprise_buffer
                .iter()
                .map(|&s| (s - mean) * (s - mean))
                .sum::<f64>()
                / self.surprise_buffer.len() as f64;
            let std = variance.sqrt();
            let lo = (mean - 0.5 * std).max(0.0);
            let hi = mean + 1.5 * std;
            (lo, hi)
        } else {
            (self.config.surprise_low, self.config.surprise_high)
        };

        if surprise <= low {
            0.1
        } else if surprise >= high {
            2.0
        } else {
            // Linear interpolation between the 0.1 and 2.0 extremes.
            let t = (surprise - low) / (high - low);
            0.1 + t * (2.0 - 0.1)
        }
    }

    /// Pushes a surprise score into the rolling window, evicting the oldest
    /// entry once the buffer holds 100 scores.
    fn push_surprise(&mut self, surprise: f64) {
        if self.surprise_buffer.len() >= 100 {
            self.surprise_buffer.pop_front();
        }
        self.surprise_buffer.push_back(surprise);
    }
}
611
612fn cache_to_matrices<L: LinAlg>(cache: &ActivationCache<L>) -> Vec<L::Matrix> {
615 let num_layers = cache.num_layers();
616 let batch_size = cache.batch_size();
617 let mut matrices = Vec::with_capacity(num_layers);
618
619 for layer_idx in 0..num_layers {
620 let samples = cache.layer(layer_idx);
621 if samples.is_empty() {
622 matrices.push(L::zeros_mat(0, 0));
623 continue;
624 }
625 let n_neurons = L::vec_len(&samples[0]);
626 let mut mat = L::zeros_mat(batch_size, n_neurons);
627 for (r, sample) in samples.iter().enumerate() {
628 for c in 0..n_neurons {
629 L::mat_set(&mut mat, r, c, L::vec_get(sample, c));
630 }
631 }
632 matrices.push(mat);
633 }
634
635 matrices
636}
637
638#[cfg(test)]
639mod tests {
640 use super::*;
641 use crate::activation::Activation;
642 use crate::layer::LayerDef;
643 use crate::pc_actor::SelectionMode;
644
    /// Baseline config used by most tests: 9-in/18-hidden/9-out tanh actor,
    /// 27-in critic (9 input + 18 latent), fixed surprise thresholds, small
    /// entropy bonus.
    fn default_config() -> PcActorCriticConfig {
        PcActorCriticConfig {
            actor: PcActorConfig {
                input_size: 9,
                hidden_layers: vec![LayerDef {
                    size: 18,
                    activation: Activation::Tanh,
                }],
                output_size: 9,
                output_activation: Activation::Tanh,
                alpha: 0.1,
                tol: 0.01,
                min_steps: 1,
                max_steps: 20,
                lr_weights: 0.01,
                synchronous: true,
                temperature: 1.0,
                local_lambda: 1.0,
                residual: false,
                rezero_init: 0.001,
            },
            critic: MlpCriticConfig {
                input_size: 27,
                hidden_layers: vec![LayerDef {
                    size: 36,
                    activation: Activation::Tanh,
                }],
                output_activation: Activation::Linear,
                lr: 0.005,
            },
            gamma: 0.95,
            surprise_low: 0.02,
            surprise_high: 0.15,
            adaptive_surprise: false,
            entropy_coeff: 0.01,
        }
    }

    /// Fresh agent from `default_config` with a fixed seed for determinism.
    fn make_agent() -> PcActorCritic {
        let agent: PcActorCritic = PcActorCritic::new(default_config(), 42).unwrap();
        agent
    }

    /// Takes one action with the agent and packages the resulting inference
    /// into a single-step trajectory with reward 1.0.
    fn make_trajectory(agent: &mut PcActorCritic) -> Vec<TrajectoryStep> {
        let input = vec![1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0, 0.0, 0.5];
        let valid = vec![2, 7];
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        vec![TrajectoryStep {
            input,
            latent_concat: infer.latent_concat,
            y_conv: infer.y_conv,
            hidden_states: infer.hidden_states,
            prediction_errors: infer.prediction_errors,
            tanh_components: infer.tanh_components,
            action,
            valid_actions: valid,
            reward: 1.0,
            surprise_score: infer.surprise_score,
            steps_used: infer.steps_used,
        }]
    }
706
    // learn([]) must be a strict no-op: zero loss, no weight changes.
    #[test]
    fn test_learn_empty_returns_zero_without_modifying_weights() {
        let mut agent: PcActorCritic = make_agent();
        let w_before = agent.actor.layers[0].weights.data.clone();
        let cw_before = agent.critic.layers[0].weights.data.clone();
        let loss = agent.learn(&[]);
        assert_eq!(loss, 0.0);
        assert_eq!(agent.actor.layers[0].weights.data, w_before);
        assert_eq!(agent.critic.layers[0].weights.data, cw_before);
    }

    // A non-empty trajectory must move the actor's first-layer weights.
    #[test]
    fn test_learn_updates_actor_weights() {
        let mut agent: PcActorCritic = make_agent();
        let trajectory = make_trajectory(&mut agent);
        let w_before = agent.actor.layers[0].weights.data.clone();
        let _ = agent.learn(&trajectory);
        assert_ne!(agent.actor.layers[0].weights.data, w_before);
    }

    // ...and the critic's first-layer weights too.
    #[test]
    fn test_learn_updates_critic_weights() {
        let mut agent: PcActorCritic = make_agent();
        let trajectory = make_trajectory(&mut agent);
        let w_before = agent.critic.layers[0].weights.data.clone();
        let _ = agent.learn(&trajectory);
        assert_ne!(agent.critic.layers[0].weights.data, w_before);
    }

    // Mean critic loss must be a finite, non-negative number.
    #[test]
    fn test_learn_returns_finite_nonneg_loss() {
        let mut agent: PcActorCritic = make_agent();
        let trajectory = make_trajectory(&mut agent);
        let loss = agent.learn(&trajectory);
        assert!(loss.is_finite(), "Loss {loss} is not finite");
        assert!(loss >= 0.0, "Loss {loss} is negative");
    }
746
    // Single-step trajectory with negative reward still yields a finite loss.
    #[test]
    fn test_learn_single_step_trajectory() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        let trajectory = vec![TrajectoryStep {
            input,
            latent_concat: infer.latent_concat,
            y_conv: infer.y_conv,
            hidden_states: infer.hidden_states,
            prediction_errors: infer.prediction_errors,
            tanh_components: infer.tanh_components,
            action,
            valid_actions: valid,
            reward: -1.0,
            surprise_score: infer.surprise_score,
            steps_used: infer.steps_used,
        }];
        let loss = agent.learn(&trajectory);
        assert!(loss.is_finite());
    }

    // Three-step trajectory with terminal reward: learn must replay the
    // stored hidden states from each step without blowing up.
    #[test]
    fn test_learn_multi_step_uses_stored_hidden_states() {
        let mut agent: PcActorCritic = make_agent();
        let inputs = [
            vec![1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0, 0.0, 0.5],
            vec![0.5, 0.5, -1.0, 0.0, 1.0, -0.5, 0.0, -1.0, 0.5],
            vec![-1.0, 0.0, 1.0, -0.5, 0.5, 0.0, 1.0, -1.0, -0.5],
        ];
        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];

        let mut trajectory = Vec::new();
        for (i, inp) in inputs.iter().enumerate() {
            let (action, infer) = agent.act(inp, &valid, SelectionMode::Training);
            trajectory.push(TrajectoryStep {
                input: inp.clone(),
                latent_concat: infer.latent_concat,
                y_conv: infer.y_conv,
                hidden_states: infer.hidden_states,
                prediction_errors: infer.prediction_errors,
                tanh_components: infer.tanh_components,
                action,
                valid_actions: valid.clone(),
                reward: if i == 2 { 1.0 } else { 0.0 },
                surprise_score: infer.surprise_score,
                steps_used: infer.steps_used,
            });
        }

        let loss = agent.learn(&trajectory);
        assert!(
            loss.is_finite(),
            "Multi-step learn should produce finite loss"
        );
        assert!(loss >= 0.0);
    }
806
    // Non-terminal TD update: bootstraps from V(s') and yields a finite loss.
    #[test]
    fn test_learn_continuous_nonterminal_uses_next_value() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let next_input = vec![-0.5; 9];
        let valid = vec![0, 1, 2];
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);

        let loss = agent.learn_continuous(
            &input,
            &infer,
            action,
            &valid,
            0.5,
            &next_input,
            &next_infer,
            false,
        );
        assert!(loss.is_finite());
    }

    // Terminal TD update: target is reward only (no bootstrap).
    #[test]
    fn test_learn_continuous_terminal_uses_reward_only() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let next_input = vec![0.0; 9];
        let valid = vec![0, 1, 2];
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);

        let loss = agent.learn_continuous(
            &input,
            &infer,
            action,
            &valid,
            1.0,
            &next_input,
            &next_infer,
            true,
        );
        assert!(loss.is_finite());
    }

    // Two identically-seeded agents fed the same transition must diverge in
    // loss depending only on the terminal flag — proving the flag matters.
    #[test]
    fn test_learn_continuous_terminal_and_nonterminal_produce_different_updates() {
        let config = default_config();
        let mut agent_term: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_nonterm: PcActorCritic = PcActorCritic::new(config, 42).unwrap();

        let input = vec![0.5; 9];
        let next_input = vec![-0.5; 9];
        let valid = vec![0, 1, 2];

        let (action, infer) = agent_term.act(&input, &valid, SelectionMode::Training);
        let (_, next_infer) = agent_term.act(&next_input, &valid, SelectionMode::Training);

        let (action2, infer2) = agent_nonterm.act(&input, &valid, SelectionMode::Training);
        let (_, next_infer2) = agent_nonterm.act(&next_input, &valid, SelectionMode::Training);

        let loss_term = agent_term.learn_continuous(
            &input,
            &infer,
            action,
            &valid,
            1.0,
            &next_input,
            &next_infer,
            true,
        );

        let loss_nonterm = agent_nonterm.learn_continuous(
            &input,
            &infer2,
            action2,
            &valid,
            1.0,
            &next_input,
            &next_infer2,
            false,
        );

        assert!(
            (loss_term - loss_nonterm).abs() > 1e-15,
            "Terminal and non-terminal should produce different losses: {loss_term} vs {loss_nonterm}"
        );
    }

    // The online update must also move actor weights, not just the critic.
    #[test]
    fn test_learn_continuous_updates_actor() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let next_input = vec![-0.5; 9];
        let valid = vec![0, 1, 2];
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
        let w_before = agent.actor.layers[0].weights.data.clone();
        let _ = agent.learn_continuous(
            &input,
            &infer,
            action,
            &valid,
            1.0,
            &next_input,
            &next_infer,
            false,
        );
        assert_ne!(agent.actor.layers[0].weights.data, w_before);
    }
927
    // Below surprise_low (0.02) the scale clamps to the 0.1 floor.
    #[test]
    fn test_surprise_scale_below_low() {
        let agent: PcActorCritic = make_agent();
        let scale = agent.surprise_scale(0.01); assert!((scale - 0.1).abs() < 1e-12, "Expected 0.1, got {scale}");
    }

    // Above surprise_high (0.15) the scale clamps to the 2.0 ceiling.
    #[test]
    fn test_surprise_scale_above_high() {
        let agent: PcActorCritic = make_agent();
        let scale = agent.surprise_scale(0.20); assert!((scale - 2.0).abs() < 1e-12, "Expected 2.0, got {scale}");
    }

    // Between the thresholds the scale interpolates strictly inside (0.1, 2.0).
    #[test]
    fn test_surprise_scale_midpoint_in_range() {
        let agent: PcActorCritic = make_agent();
        let midpoint = (0.02 + 0.15) / 2.0;
        let scale = agent.surprise_scale(midpoint);
        assert!(
            scale > 0.1 && scale < 2.0,
            "Midpoint scale {scale} out of range"
        );
    }

    // The scale mapping must be monotone non-decreasing in surprise.
    #[test]
    fn test_surprise_scale_monotone_increasing() {
        let agent: PcActorCritic = make_agent();
        let s1 = agent.surprise_scale(0.01);
        let s2 = agent.surprise_scale(0.05);
        let s3 = agent.surprise_scale(0.10);
        let s4 = agent.surprise_scale(0.20);
        assert!(s1 <= s2, "s1={s1} > s2={s2}");
        assert!(s2 <= s3, "s2={s2} > s3={s3}");
        assert!(s3 <= s4, "s3={s3} > s4={s4}");
    }

    // With adaptive_surprise on and 15 buffered scores (0.10..0.38), the
    // thresholds shift to mean +/- std multiples: extremes clamp, a value
    // near the buffer mean interpolates.
    #[test]
    fn test_adaptive_surprise_recalibrates_thresholds_after_many_episodes() {
        let mut config = default_config();
        config.adaptive_surprise = true;
        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();

        for i in 0..15 {
            agent.push_surprise(0.1 + 0.02 * i as f64);
        }

        let scale_low = agent.surprise_scale(0.0);
        assert!(
            (scale_low - 0.1).abs() < 1e-12,
            "Expected 0.1 below adaptive low: got {scale_low}"
        );

        let scale_high = agent.surprise_scale(1.0);
        assert!(
            (scale_high - 2.0).abs() < 1e-12,
            "Expected 2.0 above adaptive high: got {scale_high}"
        );

        let scale_mid = agent.surprise_scale(0.24);
        assert!(
            scale_mid > 0.1 && scale_mid < 2.0,
            "Expected interpolated value at mean, got {scale_mid}"
        );
    }
1004
    // With a strong entropy bonus, 20 episodes of rewarding whatever action
    // was sampled must not collapse the policy to a single action.
    #[test]
    fn test_entropy_regularization_prevents_policy_collapse() {
        let mut config = default_config();
        config.entropy_coeff = 0.1; let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();

        let input = vec![0.5; 9];
        let valid: Vec<usize> = (0..9).collect();

        for _ in 0..20 {
            let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
            let trajectory = vec![TrajectoryStep {
                input: input.clone(),
                latent_concat: infer.latent_concat,
                y_conv: infer.y_conv,
                hidden_states: infer.hidden_states,
                prediction_errors: infer.prediction_errors,
                tanh_components: infer.tanh_components,
                action,
                valid_actions: valid.clone(),
                reward: 1.0,
                surprise_score: infer.surprise_score,
                steps_used: infer.steps_used,
            }];
            let _ = agent.learn(&trajectory);
        }

        // Sample 50 actions post-training; more than one distinct action
        // means the policy retained entropy.
        let mut seen = std::collections::HashSet::new();
        for _ in 0..50 {
            let (action, _) = agent.act(&input, &valid, SelectionMode::Training);
            seen.insert(action);
        }
        assert!(
            seen.len() > 1,
            "Entropy regularization should prevent collapse to single action, but only saw {:?}",
            seen
        );
    }

    // act() must always return a member of the valid-action set.
    #[test]
    fn test_act_returns_valid_action() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let valid = vec![1, 3, 5, 7];
        for _ in 0..20 {
            let (action, _) = agent.act(&input, &valid, SelectionMode::Training);
            assert!(valid.contains(&action), "Action {action} not in valid set");
        }
    }

    // An empty valid-action set has no sensible answer; act() panics.
    #[test]
    #[should_panic]
    fn test_act_empty_valid_panics() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let _ = agent.act(&input, &[], SelectionMode::Training);
    }
1068
    // End-to-end learning check: reward action 4 exclusively for 200
    // episodes (linear output, no entropy bonus, entire trajectory forced to
    // action 4) and assert the greedy policy then picks action 4.
    #[test]
    fn test_learn_improves_policy_for_rewarded_action() {
        let config = PcActorCriticConfig {
            actor: PcActorConfig {
                input_size: 9,
                hidden_layers: vec![LayerDef {
                    size: 18,
                    activation: Activation::Tanh,
                }],
                output_size: 9,
                output_activation: Activation::Linear,
                alpha: 0.1,
                tol: 0.01,
                min_steps: 1,
                max_steps: 5,
                lr_weights: 0.01,
                synchronous: true,
                temperature: 1.0,
                local_lambda: 1.0,
                residual: false,
                rezero_init: 0.001,
            },
            critic: MlpCriticConfig {
                input_size: 27,
                hidden_layers: vec![LayerDef {
                    size: 36,
                    activation: Activation::Tanh,
                }],
                output_activation: Activation::Linear,
                lr: 0.005,
            },
            gamma: 0.99,
            surprise_low: 0.02,
            surprise_high: 0.15,
            adaptive_surprise: false,
            entropy_coeff: 0.0, };
        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();

        let input = vec![0.0; 9];
        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
        let target_action = 4; for _ in 0..200 {
            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
            let trajectory = vec![TrajectoryStep {
                input: input.clone(),
                latent_concat: infer.latent_concat,
                y_conv: infer.y_conv,
                hidden_states: infer.hidden_states,
                prediction_errors: infer.prediction_errors,
                tanh_components: infer.tanh_components,
                action: target_action,
                valid_actions: valid.clone(),
                reward: 1.0,
                surprise_score: infer.surprise_score,
                steps_used: infer.steps_used,
            }];
            agent.learn(&trajectory);
        }

        // Greedy (Play) selection after training.
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Play);

        let logit_4 = infer.y_conv[4];
        let max_other = valid
            .iter()
            .filter(|&&a| a != 4)
            .map(|&a| infer.y_conv[a])
            .fold(f64::NEG_INFINITY, f64::max);

        // Diagnostic dump of the final logits for debugging failures.
        eprintln!(
            "DIAGNOSTIC: action={action}, logit[4]={logit_4:.4}, max_other={max_other:.4}, \
             y_conv={:?}",
            infer
                .y_conv
                .iter()
                .map(|v| format!("{v:.3}"))
                .collect::<Vec<_>>()
        );

        assert_eq!(
            action, target_action,
            "After 200 episodes rewarding action 4, agent should prefer it. Got action {action}"
        );
    }
1161
    // Zero temperature must be rejected (validated inside the actor config),
    // with "temperature" mentioned in the error text.
    #[test]
    fn test_new_returns_error_zero_temperature() {
        let mut config = default_config();
        config.actor.temperature = 0.0;
        let err = PcActorCritic::new(config, 42)
            .map(|_: PcActorCritic| ())
            .unwrap_err();
        assert!(format!("{err}").contains("temperature"));
    }

    // Zero-size inputs must be rejected at construction.
    #[test]
    fn test_new_returns_error_zero_input_size() {
        let mut config = default_config();
        config.actor.input_size = 0;
        config.critic.input_size = 0;
        assert!(PcActorCritic::new(config, 42)
            .map(|_: PcActorCritic| ())
            .is_err());
    }

    // Zero-size actor output must be rejected at construction.
    #[test]
    fn test_new_returns_error_zero_output_size() {
        let mut config = default_config();
        config.actor.output_size = 0;
        assert!(PcActorCritic::new(config, 42)
            .map(|_: PcActorCritic| ())
            .is_err());
    }

    // gamma outside [0, 1] is rejected by PcActorCritic::new itself.
    #[test]
    fn test_new_returns_error_negative_gamma() {
        let mut config = default_config();
        config.gamma = -0.1;
        let err = PcActorCritic::new(config, 42)
            .map(|_: PcActorCritic| ())
            .unwrap_err();
        assert!(format!("{err}").contains("gamma"));
    }
1202
    // A fresh cache has no recorded samples.
    #[test]
    fn test_activation_cache_new_creates_empty() {
        let cache: ActivationCache = ActivationCache::new(3);
        assert_eq!(cache.batch_size(), 0);
    }

    // Recording one inference bumps batch_size to 1.
    #[test]
    fn test_activation_cache_record_increments_batch_size() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let valid = vec![0, 1, 2];
        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);

        let num_hidden = infer.hidden_states.len();
        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
        cache.record(&infer.hidden_states);
        assert_eq!(cache.batch_size(), 1);
    }

    // Five recordings give batch_size 5.
    #[test]
    fn test_activation_cache_record_multiple() {
        let mut agent: PcActorCritic = make_agent();
        let valid = vec![0, 1, 2];
        let init_input = vec![0.5; 9];
        // Probe one inference just to learn how many hidden layers there are.
        let num_hidden = {
            let (_, infer) = agent.act(&init_input, &valid, SelectionMode::Training);
            infer.hidden_states.len()
        };

        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
        for i in 0..5 {
            let input = vec![i as f64 * 0.1; 9];
            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
            cache.record(&infer.hidden_states);
        }
        assert_eq!(cache.batch_size(), 5);
    }

    // Recorded vectors must round-trip unchanged, layer by layer.
    #[test]
    fn test_activation_cache_recorded_values_match_hidden_states() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let valid = vec![0, 1, 2];
        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);

        let num_hidden = infer.hidden_states.len();
        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
        cache.record(&infer.hidden_states);

        for (layer_idx, expected) in infer.hidden_states.iter().enumerate() {
            let layer_data = cache.layer(layer_idx);
            assert_eq!(layer_data.len(), 1);
            assert_eq!(layer_data[0], *expected);
        }
    }

    // num_layers reflects the constructor argument after recording.
    #[test]
    fn test_activation_cache_layer_count() {
        let mut agent: PcActorCritic = make_agent();
        let input = vec![0.5; 9];
        let valid = vec![0, 1, 2];
        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);

        let num_hidden = infer.hidden_states.len();
        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
        cache.record(&infer.hidden_states);

        assert_eq!(cache.num_layers(), num_hidden);
    }

    // Every layer accumulates exactly one sample per record() call.
    #[test]
    fn test_activation_cache_layer_sample_count() {
        let mut agent: PcActorCritic = make_agent();
        let valid = vec![0, 1, 2];
        let init_input = vec![0.5; 9];
        let num_hidden = {
            let (_, infer) = agent.act(&init_input, &valid, SelectionMode::Training);
            infer.hidden_states.len()
        };

        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
        for i in 0..10 {
            let input = vec![i as f64 * 0.1; 9];
            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
            cache.record(&infer.hidden_states);
        }

        for layer_idx in 0..num_hidden {
            assert_eq!(
                cache.layer(layer_idx).len(),
                10,
                "Layer {layer_idx} should have 10 samples"
            );
        }
    }
1303
    /// Drives `agent` through `batch_size` deterministic inputs, recording
    /// actor hidden states and critic hidden states (via forward_with_hidden
    /// on the concatenated input+latent vector) into a pair of caches.
    fn build_caches_for_agent(
        agent: &mut PcActorCritic,
        batch_size: usize,
    ) -> (ActivationCache, ActivationCache) {
        let num_actor_hidden = agent.config.actor.hidden_layers.len();
        let num_critic_hidden = agent.config.critic.hidden_layers.len();
        let mut actor_cache: ActivationCache = ActivationCache::new(num_actor_hidden);
        let mut critic_cache: ActivationCache = ActivationCache::new(num_critic_hidden);
        let valid: Vec<usize> = (0..agent.config.actor.output_size).collect();
        for i in 0..batch_size {
            // Deterministic pseudo-varied inputs via a sine over sample/feature index.
            let input: Vec<f64> = (0..agent.config.actor.input_size)
                .map(|j| ((i * 9 + j) as f64 * 0.1).sin())
                .collect();
            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
            actor_cache.record(&infer.hidden_states);
            let mut critic_input = input;
            critic_input.extend_from_slice(&infer.latent_concat);
            let (_value, critic_hidden) = agent.critic.forward_with_hidden(&critic_input);
            critic_cache.record(&critic_hidden);
        }
        (actor_cache, critic_cache)
    }
1328
    // Crossover of two seeded parents succeeds and preserves architecture.
    #[test]
    fn test_agent_crossover_produces_valid_agent() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);

        let child: PcActorCritic = PcActorCritic::crossover(
            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
        )
        .unwrap();

        assert_eq!(
            child.config.actor.hidden_layers.len(),
            agent_a.config.actor.hidden_layers.len()
        );
    }

    // The child's actor weights must be a genuine blend: equal to neither parent.
    #[test]
    fn test_agent_crossover_actor_weights_differ() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);

        let child: PcActorCritic = PcActorCritic::crossover(
            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
        )
        .unwrap();

        assert_ne!(
            child.actor.layers[0].weights.data,
            agent_a.actor.layers[0].weights.data
        );
        assert_ne!(
            child.actor.layers[0].weights.data,
            agent_b.actor.layers[0].weights.data
        );
    }

    // Same property for the critic weights.
    #[test]
    fn test_agent_crossover_critic_weights_differ() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);

        let child: PcActorCritic = PcActorCritic::crossover(
            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
        )
        .unwrap();

        assert_ne!(
            child.critic.layers[0].weights.data,
            agent_a.critic.layers[0].weights.data
        );
        assert_ne!(
            child.critic.layers[0].weights.data,
            agent_b.critic.layers[0].weights.data
        );
    }

    // The crossover product must be functional: it can act and its action
    // stays within the valid set.
    #[test]
    fn test_agent_crossover_child_can_infer() {
        let config = default_config();
        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();

        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);

        let mut child: PcActorCritic = PcActorCritic::crossover(
            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
        )
        .unwrap();

        let input = vec![0.5; 9];
        let valid = vec![0, 1, 2, 3, 4];
        let (action, _) = child.act(&input, &valid, SelectionMode::Training);
        assert!(valid.contains(&action), "Action {action} not in valid set");
    }
1418
1419 #[test]
1420 fn test_agent_crossover_child_can_learn() {
1421 let config = default_config();
1422 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1423 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1424
1425 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1426 let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1427
1428 let mut child: PcActorCritic = PcActorCritic::crossover(
1429 &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1430 )
1431 .unwrap();
1432
1433 let trajectory = make_trajectory(&mut child);
1434 let loss = child.learn(&trajectory);
1435 assert!(loss.is_finite(), "Child learn loss not finite: {loss}");
1436 }
1437
1438 #[test]
1439 fn test_agent_crossover_mismatched_batch_size_error() {
1440 let config = default_config();
1441 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1442 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1443
1444 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1445 let (ac_b, _cc_b) = build_caches_for_agent(&mut agent_b, 30); let (_, cc_b_match) = build_caches_for_agent(&mut agent_b, 50);
1447
1448 let result = PcActorCritic::crossover(
1450 &agent_a,
1451 &agent_b,
1452 &ac_a,
1453 &ac_b,
1454 &cc_a,
1455 &cc_b_match,
1456 0.5,
1457 config,
1458 99,
1459 );
1460 assert!(result.is_err(), "Mismatched actor batch sizes should error");
1461 }
1462
1463 #[test]
1466 fn test_agent_crossover_with_separate_critic_caches() {
1467 let config = default_config();
1468 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1469 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1470
1471 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1472 let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1473
1474 let child: PcActorCritic = PcActorCritic::crossover(
1475 &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1476 )
1477 .unwrap();
1478
1479 assert_eq!(child.critic.layers.len(), agent_a.critic.layers.len());
1480 }
1481
1482 #[test]
1483 fn test_agent_crossover_critic_uses_own_caches() {
1484 let config = default_config();
1485 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1486 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1487
1488 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1489 let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1490
1491 let child: PcActorCritic = PcActorCritic::crossover(
1492 &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1493 )
1494 .unwrap();
1495
1496 assert_ne!(
1497 child.critic.layers[0].weights.data,
1498 agent_a.critic.layers[0].weights.data
1499 );
1500 assert_ne!(
1501 child.critic.layers[0].weights.data,
1502 agent_b.critic.layers[0].weights.data
1503 );
1504 }
1505
1506 #[test]
1507 fn test_agent_crossover_mismatched_critic_batch_error() {
1508 let config = default_config();
1509 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1510 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1511
1512 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1513 let (ac_b, _) = build_caches_for_agent(&mut agent_b, 50);
1514 let (_, cc_b_small) = build_caches_for_agent(&mut agent_b, 30);
1516
1517 let result = PcActorCritic::crossover(
1518 &agent_a,
1519 &agent_b,
1520 &ac_a,
1521 &ac_b,
1522 &cc_a,
1523 &cc_b_small,
1524 0.5,
1525 config,
1526 99,
1527 );
1528 assert!(
1529 result.is_err(),
1530 "Mismatched critic batch sizes should error"
1531 );
1532 }
1533
1534 #[test]
1537 fn test_activation_cache_accessible_from_crate() {
1538 let _cache: crate::pc_actor_critic::ActivationCache = ActivationCache::new(1);
1540 }
1541
1542 #[test]
1543 fn test_cca_neuron_alignment_accessible_from_crate() {
1544 use crate::linalg::cpu::CpuLinAlg;
1546 use crate::linalg::LinAlg;
1547 let mat = CpuLinAlg::zeros_mat(10, 3);
1548 let _perm = crate::matrix::cca_neuron_alignment::<CpuLinAlg>(&mat, &mat).unwrap();
1549 }
1550}