1use std::collections::VecDeque;
15
16use rand::rngs::StdRng;
17use serde::{Deserialize, Serialize};
18
19use crate::error::PcError;
20use crate::linalg::cpu::CpuLinAlg;
21use crate::linalg::LinAlg;
22use crate::mlp_critic::{MlpCritic, MlpCriticConfig};
23use crate::pc_actor::{InferResult, PcActor, PcActorConfig, SelectionMode};
24
/// Serde default for `PcActorCriticConfig::surprise_buffer_size` (rolling
/// window of recent surprise scores kept when `adaptive_surprise` is on).
fn default_surprise_buffer_size() -> usize {
    100
}
29
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Hyperparameters for the combined predictive-coding actor + MLP critic agent.
pub struct PcActorCriticConfig {
    /// Configuration of the predictive-coding actor network.
    pub actor: PcActorConfig,
    /// Configuration of the MLP value critic.
    pub critic: MlpCriticConfig,
    /// Discount factor for returns; `PcActorCritic::new` rejects values outside [0.0, 1.0].
    pub gamma: f64,
    /// Surprise at or below this threshold maps to the minimum actor LR scale (0.1).
    pub surprise_low: f64,
    /// Surprise at or above this threshold maps to the maximum actor LR scale (2.0).
    pub surprise_high: f64,
    /// When true, thresholds are re-derived from the rolling surprise buffer
    /// (mean ± std based) once at least 10 samples have been collected.
    pub adaptive_surprise: bool,
    /// Capacity of the rolling surprise buffer (serde default: 100).
    #[serde(default = "default_surprise_buffer_size")]
    pub surprise_buffer_size: usize,
    /// Coefficient of the entropy term mixed into the policy-gradient delta.
    pub entropy_coeff: f64,
}
88
#[derive(Debug, Clone)]
/// One recorded environment transition, carrying everything `learn` needs to
/// replay the actor's inference state for that step without re-running inference.
pub struct TrajectoryStep<L: LinAlg = CpuLinAlg> {
    /// Observation fed to the actor at this step.
    pub input: L::Vector,
    /// Concatenated latent state; appended to the input to form the critic's input.
    pub latent_concat: L::Vector,
    /// Converged output logits; divided by temperature before the masked softmax.
    pub y_conv: L::Vector,
    /// Per-layer hidden states captured at inference time.
    pub hidden_states: Vec<L::Vector>,
    /// Per-layer prediction errors captured at inference time.
    pub prediction_errors: Vec<L::Vector>,
    /// Optional per-layer tanh components from inference (may be `None` per layer).
    pub tanh_components: Vec<Option<L::Vector>>,
    /// Index of the action actually taken.
    pub action: usize,
    /// Action indices that were legal at this step.
    pub valid_actions: Vec<usize>,
    /// Reward received after taking `action`.
    pub reward: f64,
    /// Surprise score reported by inference; drives actor LR scaling in `learn`.
    pub surprise_score: f64,
    /// Number of inference iterations used.
    pub steps_used: usize,
}
117
#[derive(Debug, Clone)]
/// Per-layer activation samples collected across a batch of inferences;
/// converted to matrices and consumed by `PcActorCritic::crossover`.
pub struct ActivationCache<L: LinAlg = CpuLinAlg> {
    // layers[layer_idx] holds one recorded activation vector per sample.
    layers: Vec<Vec<L::Vector>>,
}
141
142impl<L: LinAlg> ActivationCache<L> {
143 pub fn new(num_layers: usize) -> Self {
149 Self {
150 layers: (0..num_layers).map(|_| Vec::new()).collect(),
151 }
152 }
153
154 pub fn batch_size(&self) -> usize {
156 self.layers.first().map_or(0, |l| l.len())
157 }
158
159 pub fn num_layers(&self) -> usize {
161 self.layers.len()
162 }
163
164 pub fn record(&mut self, hidden_states: &[L::Vector]) {
170 for (layer, state) in self.layers.iter_mut().zip(hidden_states.iter()) {
171 layer.push(state.clone());
172 }
173 }
174
175 pub fn layer(&self, layer_idx: usize) -> &[L::Vector] {
181 &self.layers[layer_idx]
182 }
183}
184
#[derive(Debug)]
/// Actor-critic agent pairing a predictive-coding actor (policy) with an
/// MLP critic (state value), plus surprise-modulated learning-rate scaling.
pub struct PcActorCritic<L: LinAlg = CpuLinAlg> {
    /// Policy network (predictive-coding).
    pub(crate) actor: PcActor<L>,
    /// State-value network; its input is the observation concatenated with
    /// the actor's latent state.
    pub(crate) critic: MlpCritic<L>,
    /// Hyperparameters shared by both networks and the learning rules.
    pub config: PcActorCriticConfig,
    // Seeded RNG used for stochastic action selection.
    rng: StdRng,
    // Rolling window of recent surprise scores for adaptive thresholding.
    surprise_buffer: VecDeque<f64>,
}
204
impl<L: LinAlg> PcActorCritic<L> {
    /// Builds a new agent, seeding actor, critic and the action-selection RNG
    /// deterministically from `seed`.
    ///
    /// # Errors
    /// Returns `PcError::ConfigValidation` if `gamma` lies outside [0.0, 1.0];
    /// actor/critic construction may surface their own validation errors.
    pub fn new(config: PcActorCriticConfig, seed: u64) -> Result<Self, PcError> {
        if !(0.0..=1.0).contains(&config.gamma) {
            return Err(PcError::ConfigValidation(format!(
                "gamma must be in [0.0, 1.0], got {}",
                config.gamma
            )));
        }

        use rand::SeedableRng;
        let mut rng = StdRng::seed_from_u64(seed);
        let actor = PcActor::<L>::new(config.actor.clone(), &mut rng)?;
        let critic = MlpCritic::<L>::new(config.critic.clone(), &mut rng)?;
        Ok(Self {
            actor,
            critic,
            config,
            rng,
            surprise_buffer: VecDeque::new(),
        })
    }

    /// Produces a child agent by crossing over two parents, using activation
    /// caches (one per network, per parent) to align representations.
    ///
    /// The child's surprise buffer starts empty.
    ///
    /// # Errors
    /// Returns `PcError::DimensionMismatch` when the two actor caches (or the
    /// two critic caches) were recorded over different batch sizes; the
    /// per-network crossover calls may return their own errors.
    /// NOTE(review): layer counts of the caches are not validated here —
    /// presumably the per-network `crossover` impls check them; confirm.
    #[allow(clippy::too_many_arguments)]
    pub fn crossover(
        parent_a: &PcActorCritic<L>,
        parent_b: &PcActorCritic<L>,
        actor_cache_a: &ActivationCache<L>,
        actor_cache_b: &ActivationCache<L>,
        critic_cache_a: &ActivationCache<L>,
        critic_cache_b: &ActivationCache<L>,
        alpha: f64,
        child_config: PcActorCriticConfig,
        seed: u64,
    ) -> Result<Self, PcError> {
        if actor_cache_a.batch_size() != actor_cache_b.batch_size() {
            return Err(PcError::DimensionMismatch {
                expected: actor_cache_a.batch_size(),
                got: actor_cache_b.batch_size(),
                context: "actor activation cache batch sizes must match for crossover",
            });
        }
        if critic_cache_a.batch_size() != critic_cache_b.batch_size() {
            return Err(PcError::DimensionMismatch {
                expected: critic_cache_a.batch_size(),
                got: critic_cache_b.batch_size(),
                context: "critic activation cache batch sizes must match for crossover",
            });
        }

        // Densify the per-layer sample lists into matrices for the crossover math.
        let actor_cache_mats_a = cache_to_matrices::<L>(actor_cache_a);
        let actor_cache_mats_b = cache_to_matrices::<L>(actor_cache_b);
        let critic_cache_mats_a = cache_to_matrices::<L>(critic_cache_a);
        let critic_cache_mats_b = cache_to_matrices::<L>(critic_cache_b);

        use rand::SeedableRng;
        let mut rng = StdRng::seed_from_u64(seed);

        let actor = PcActor::<L>::crossover(
            &parent_a.actor,
            &parent_b.actor,
            &actor_cache_mats_a,
            &actor_cache_mats_b,
            alpha,
            child_config.actor.clone(),
            &mut rng,
        )?;

        let critic = MlpCritic::<L>::crossover(
            &parent_a.critic,
            &parent_b.critic,
            &critic_cache_mats_a,
            &critic_cache_mats_b,
            alpha,
            child_config.critic.clone(),
            &mut rng,
        )?;

        Ok(Self {
            actor,
            critic,
            config: child_config,
            rng,
            surprise_buffer: VecDeque::new(),
        })
    }

    /// Assembles an agent from pre-built components without any validation.
    /// The surprise buffer starts empty.
    pub fn from_parts(
        config: PcActorCriticConfig,
        actor: PcActor<L>,
        critic: MlpCritic<L>,
        rng: StdRng,
    ) -> Self {
        Self {
            actor,
            critic,
            config,
            rng,
            surprise_buffer: VecDeque::new(),
        }
    }

    /// Runs actor inference on `input` without selecting an action.
    pub fn infer(&self, input: &[f64]) -> InferResult<L> {
        self.actor.infer(input)
    }

    /// Runs inference and selects an action from `valid_actions` according to
    /// `mode`, consuming randomness from the agent's seeded RNG.
    /// Returns the chosen action together with the full inference result.
    pub fn act(
        &mut self,
        input: &[f64],
        valid_actions: &[usize],
        mode: SelectionMode,
    ) -> (usize, InferResult<L>) {
        let infer_result = self.actor.infer(input);
        let action =
            self.actor
                .select_action(&infer_result.y_conv, valid_actions, mode, &mut self.rng);
        (action, infer_result)
    }

    /// Monte-Carlo policy-gradient update over a whole episode.
    ///
    /// Computes discounted returns backwards, regresses the critic toward each
    /// return, and updates the actor with an advantage-weighted softmax
    /// gradient plus an entropy term, scaled by the step's surprise score.
    ///
    /// Returns the mean critic loss over the trajectory; 0.0 (and no weight
    /// changes) for an empty trajectory.
    pub fn learn(&mut self, trajectory: &[TrajectoryStep<L>]) -> f64 {
        if trajectory.is_empty() {
            return 0.0;
        }

        let n = trajectory.len();

        // Discounted returns, computed from the final step backwards:
        // G_t = r_t + gamma * G_{t+1}.
        let mut returns = vec![0.0; n];
        returns[n - 1] = trajectory[n - 1].reward;
        for t in (0..n - 1).rev() {
            returns[t] = trajectory[t].reward + self.config.gamma * returns[t + 1];
        }

        let mut total_loss = 0.0;

        for (t, step) in trajectory.iter().enumerate() {
            // Critic input = observation concatenated with the actor's latent state.
            let input_vec = L::vec_to_vec(&step.input);
            let latent_vec = L::vec_to_vec(&step.latent_concat);
            let mut critic_input = input_vec.clone();
            critic_input.extend_from_slice(&latent_vec);

            // Advantage uses the value estimate from *before* the critic update.
            let value = self.critic.forward(&critic_input);
            let advantage = returns[t] - value;

            let loss = self.critic.update(&critic_input, returns[t]);
            total_loss += loss;

            // Temperature-scaled logits -> masked softmax over valid actions.
            let y_conv_vec = L::vec_to_vec(&step.y_conv);
            let scaled: Vec<f64> = y_conv_vec
                .iter()
                .map(|&v| v / self.actor.config.temperature)
                .collect();
            let scaled_l = L::vec_from_slice(&scaled);
            let pi_l = L::softmax_masked(&scaled_l, &step.valid_actions);
            let pi = L::vec_to_vec(&pi_l);

            // Policy-gradient delta: (pi - onehot(action)) on valid actions...
            let mut delta = vec![0.0; pi.len()];
            for &i in &step.valid_actions {
                delta[i] = pi[i];
            }
            delta[step.action] -= 1.0;

            // ...weighted by the advantage...
            for &i in &step.valid_actions {
                delta[i] *= advantage;
            }

            // ...plus an entropy term, (log pi + 1), floored at 1e-10 for stability.
            for &i in &step.valid_actions {
                let log_pi = (pi[i].max(1e-10)).ln();
                delta[i] -= self.config.entropy_coeff * (log_pi + 1.0);
            }

            // Surprise-modulated learning-rate scale in [0.1, 2.0].
            let s_scale = self.surprise_scale(step.surprise_score);

            // Reconstruct the inference result from the stored trajectory step.
            // NOTE(review): `converged` is not stored in TrajectoryStep, so it
            // is set to false here — presumably unused by update_weights; confirm.
            let stored_infer = InferResult {
                y_conv: step.y_conv.clone(),
                latent_concat: step.latent_concat.clone(),
                hidden_states: step.hidden_states.clone(),
                prediction_errors: step.prediction_errors.clone(),
                surprise_score: step.surprise_score,
                steps_used: step.steps_used,
                converged: false,
                tanh_components: step.tanh_components.clone(),
            };
            self.actor
                .update_weights(&delta, &stored_infer, &input_vec, s_scale);

            if self.config.adaptive_surprise {
                self.push_surprise(step.surprise_score);
            }
        }

        total_loss / n as f64
    }

    /// One-step TD(0) actor-critic update for continuous (online) learning.
    ///
    /// Target is `reward + gamma * V(s')` (just `reward` when `terminal`);
    /// the TD error replaces the Monte-Carlo advantage used by `learn`.
    /// Returns the critic loss for this single update.
    #[allow(clippy::too_many_arguments)]
    pub fn learn_continuous(
        &mut self,
        input: &[f64],
        infer: &InferResult<L>,
        action: usize,
        valid_actions: &[usize],
        reward: f64,
        next_input: &[f64],
        next_infer: &InferResult<L>,
        terminal: bool,
    ) -> f64 {
        let latent_vec = L::vec_to_vec(&infer.latent_concat);
        let mut critic_input = input.to_vec();
        critic_input.extend_from_slice(&latent_vec);

        let next_latent_vec = L::vec_to_vec(&next_infer.latent_concat);
        let mut next_critic_input = next_input.to_vec();
        next_critic_input.extend_from_slice(&next_latent_vec);

        let v_s = self.critic.forward(&critic_input);
        // Terminal states bootstrap with 0 instead of V(s').
        let v_next = if terminal {
            0.0
        } else {
            self.critic.forward(&next_critic_input)
        };

        let target = reward
            + if terminal {
                0.0
            } else {
                self.config.gamma * v_next
            };
        let td_error = target - v_s;

        let loss = self.critic.update(&critic_input, target);

        // Same policy-gradient construction as `learn`, with the TD error
        // standing in for the advantage.
        let y_conv_vec = L::vec_to_vec(&infer.y_conv);
        let scaled: Vec<f64> = y_conv_vec
            .iter()
            .map(|&v| v / self.actor.config.temperature)
            .collect();
        let scaled_l = L::vec_from_slice(&scaled);
        let pi_l = L::softmax_masked(&scaled_l, valid_actions);
        let pi = L::vec_to_vec(&pi_l);

        let mut delta = vec![0.0; pi.len()];
        for &i in valid_actions {
            delta[i] = pi[i];
        }
        delta[action] -= 1.0;

        for &i in valid_actions {
            delta[i] *= td_error;
        }

        // Entropy term, matching `learn`.
        for &i in valid_actions {
            let log_pi = (pi[i].max(1e-10)).ln();
            delta[i] -= self.config.entropy_coeff * (log_pi + 1.0);
        }

        let s_scale = self.surprise_scale(infer.surprise_score);
        self.actor.update_weights(&delta, infer, input, s_scale);

        if self.config.adaptive_surprise {
            self.push_surprise(infer.surprise_score);
        }

        loss
    }

    /// Maps a surprise score to an actor learning-rate scale in [0.1, 2.0]:
    /// 0.1 at/below the low threshold, 2.0 at/above the high one, linear in
    /// between. With `adaptive_surprise` and at least 10 buffered samples the
    /// thresholds become (mean - 0.5*std, clamped >= 0) and (mean + 1.5*std).
    pub fn surprise_scale(&self, surprise: f64) -> f64 {
        let (low, high) = if self.config.adaptive_surprise && self.surprise_buffer.len() >= 10 {
            let mean = self.surprise_buffer.iter().sum::<f64>() / self.surprise_buffer.len() as f64;
            let variance = self
                .surprise_buffer
                .iter()
                .map(|&s| (s - mean) * (s - mean))
                .sum::<f64>()
                / self.surprise_buffer.len() as f64;
            let std = variance.sqrt();
            let lo = (mean - 0.5 * std).max(0.0);
            let hi = mean + 1.5 * std;
            (lo, hi)
        } else {
            (self.config.surprise_low, self.config.surprise_high)
        };

        if surprise <= low {
            0.1
        } else if surprise >= high {
            2.0
        } else {
            // Linear interpolation between the two extreme scales.
            let t = (surprise - low) / (high - low);
            0.1 + t * (2.0 - 0.1)
        }
    }

    // Appends to the rolling surprise buffer, evicting the oldest entry once
    // the configured capacity is reached (bounded FIFO).
    fn push_surprise(&mut self, surprise: f64) {
        if self.surprise_buffer.len() >= self.config.surprise_buffer_size {
            self.surprise_buffer.pop_front();
        }
        self.surprise_buffer.push_back(surprise);
    }
}
621
622fn cache_to_matrices<L: LinAlg>(cache: &ActivationCache<L>) -> Vec<L::Matrix> {
625 let num_layers = cache.num_layers();
626 let batch_size = cache.batch_size();
627 let mut matrices = Vec::with_capacity(num_layers);
628
629 for layer_idx in 0..num_layers {
630 let samples = cache.layer(layer_idx);
631 if samples.is_empty() {
632 matrices.push(L::zeros_mat(0, 0));
633 continue;
634 }
635 let n_neurons = L::vec_len(&samples[0]);
636 let mut mat = L::zeros_mat(batch_size, n_neurons);
637 for (r, sample) in samples.iter().enumerate() {
638 for c in 0..n_neurons {
639 L::mat_set(&mut mat, r, c, L::vec_get(sample, c));
640 }
641 }
642 matrices.push(mat);
643 }
644
645 matrices
646}
647
648#[cfg(test)]
649mod tests {
650 use super::*;
651 use crate::activation::Activation;
652 use crate::layer::LayerDef;
653 use crate::pc_actor::SelectionMode;
654
655 fn default_config() -> PcActorCriticConfig {
656 PcActorCriticConfig {
657 actor: PcActorConfig {
658 input_size: 9,
659 hidden_layers: vec![LayerDef {
660 size: 18,
661 activation: Activation::Tanh,
662 }],
663 output_size: 9,
664 output_activation: Activation::Tanh,
665 alpha: 0.1,
666 tol: 0.01,
667 min_steps: 1,
668 max_steps: 20,
669 lr_weights: 0.01,
670 synchronous: true,
671 temperature: 1.0,
672 local_lambda: 1.0,
673 residual: false,
674 rezero_init: 0.001,
675 },
676 critic: MlpCriticConfig {
677 input_size: 27,
678 hidden_layers: vec![LayerDef {
679 size: 36,
680 activation: Activation::Tanh,
681 }],
682 output_activation: Activation::Linear,
683 lr: 0.005,
684 },
685 gamma: 0.95,
686 surprise_low: 0.02,
687 surprise_high: 0.15,
688 adaptive_surprise: false,
689 surprise_buffer_size: 100,
690 entropy_coeff: 0.01,
691 }
692 }
693
694 fn make_agent() -> PcActorCritic {
695 let agent: PcActorCritic = PcActorCritic::new(default_config(), 42).unwrap();
696 agent
697 }
698
699 fn make_trajectory(agent: &mut PcActorCritic) -> Vec<TrajectoryStep> {
700 let input = vec![1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0, 0.0, 0.5];
701 let valid = vec![2, 7];
702 let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
703 vec![TrajectoryStep {
704 input,
705 latent_concat: infer.latent_concat,
706 y_conv: infer.y_conv,
707 hidden_states: infer.hidden_states,
708 prediction_errors: infer.prediction_errors,
709 tanh_components: infer.tanh_components,
710 action,
711 valid_actions: valid,
712 reward: 1.0,
713 surprise_score: infer.surprise_score,
714 steps_used: infer.steps_used,
715 }]
716 }
717
718 #[test]
721 fn test_learn_empty_returns_zero_without_modifying_weights() {
722 let mut agent: PcActorCritic = make_agent();
723 let w_before = agent.actor.layers[0].weights.data.clone();
724 let cw_before = agent.critic.layers[0].weights.data.clone();
725 let loss = agent.learn(&[]);
726 assert_eq!(loss, 0.0);
727 assert_eq!(agent.actor.layers[0].weights.data, w_before);
728 assert_eq!(agent.critic.layers[0].weights.data, cw_before);
729 }
730
731 #[test]
732 fn test_learn_updates_actor_weights() {
733 let mut agent: PcActorCritic = make_agent();
734 let trajectory = make_trajectory(&mut agent);
735 let w_before = agent.actor.layers[0].weights.data.clone();
736 let _ = agent.learn(&trajectory);
737 assert_ne!(agent.actor.layers[0].weights.data, w_before);
738 }
739
740 #[test]
741 fn test_learn_updates_critic_weights() {
742 let mut agent: PcActorCritic = make_agent();
743 let trajectory = make_trajectory(&mut agent);
744 let w_before = agent.critic.layers[0].weights.data.clone();
745 let _ = agent.learn(&trajectory);
746 assert_ne!(agent.critic.layers[0].weights.data, w_before);
747 }
748
749 #[test]
750 fn test_learn_returns_finite_nonneg_loss() {
751 let mut agent: PcActorCritic = make_agent();
752 let trajectory = make_trajectory(&mut agent);
753 let loss = agent.learn(&trajectory);
754 assert!(loss.is_finite(), "Loss {loss} is not finite");
755 assert!(loss >= 0.0, "Loss {loss} is negative");
756 }
757
758 #[test]
759 fn test_learn_single_step_trajectory() {
760 let mut agent: PcActorCritic = make_agent();
761 let input = vec![0.5; 9];
762 let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
763 let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
764 let trajectory = vec![TrajectoryStep {
765 input,
766 latent_concat: infer.latent_concat,
767 y_conv: infer.y_conv,
768 hidden_states: infer.hidden_states,
769 prediction_errors: infer.prediction_errors,
770 tanh_components: infer.tanh_components,
771 action,
772 valid_actions: valid,
773 reward: -1.0,
774 surprise_score: infer.surprise_score,
775 steps_used: infer.steps_used,
776 }];
777 let loss = agent.learn(&trajectory);
778 assert!(loss.is_finite());
779 }
780
781 #[test]
782 fn test_learn_multi_step_uses_stored_hidden_states() {
783 let mut agent: PcActorCritic = make_agent();
785 let inputs = [
786 vec![1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0, 0.0, 0.5],
787 vec![0.5, 0.5, -1.0, 0.0, 1.0, -0.5, 0.0, -1.0, 0.5],
788 vec![-1.0, 0.0, 1.0, -0.5, 0.5, 0.0, 1.0, -1.0, -0.5],
789 ];
790 let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
791
792 let mut trajectory = Vec::new();
793 for (i, inp) in inputs.iter().enumerate() {
794 let (action, infer) = agent.act(inp, &valid, SelectionMode::Training);
795 trajectory.push(TrajectoryStep {
796 input: inp.clone(),
797 latent_concat: infer.latent_concat,
798 y_conv: infer.y_conv,
799 hidden_states: infer.hidden_states,
800 prediction_errors: infer.prediction_errors,
801 tanh_components: infer.tanh_components,
802 action,
803 valid_actions: valid.clone(),
804 reward: if i == 2 { 1.0 } else { 0.0 },
805 surprise_score: infer.surprise_score,
806 steps_used: infer.steps_used,
807 });
808 }
809
810 let loss = agent.learn(&trajectory);
811 assert!(
812 loss.is_finite(),
813 "Multi-step learn should produce finite loss"
814 );
815 assert!(loss >= 0.0);
816 }
817
818 #[test]
821 fn test_learn_continuous_nonterminal_uses_next_value() {
822 let mut agent: PcActorCritic = make_agent();
823 let input = vec![0.5; 9];
824 let next_input = vec![-0.5; 9];
825 let valid = vec![0, 1, 2];
826 let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
827 let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
828
829 let loss = agent.learn_continuous(
831 &input,
832 &infer,
833 action,
834 &valid,
835 0.5,
836 &next_input,
837 &next_infer,
838 false,
839 );
840 assert!(loss.is_finite());
841 }
842
843 #[test]
844 fn test_learn_continuous_terminal_uses_reward_only() {
845 let mut agent: PcActorCritic = make_agent();
846 let input = vec![0.5; 9];
847 let next_input = vec![0.0; 9];
848 let valid = vec![0, 1, 2];
849 let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
850 let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
851
852 let loss = agent.learn_continuous(
854 &input,
855 &infer,
856 action,
857 &valid,
858 1.0,
859 &next_input,
860 &next_infer,
861 true,
862 );
863 assert!(loss.is_finite());
864 }
865
866 #[test]
867 fn test_learn_continuous_terminal_and_nonterminal_produce_different_updates() {
868 let config = default_config();
870 let mut agent_term: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
871 let mut agent_nonterm: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
872
873 let input = vec![0.5; 9];
874 let next_input = vec![-0.5; 9];
875 let valid = vec![0, 1, 2];
876
877 let (action, infer) = agent_term.act(&input, &valid, SelectionMode::Training);
879 let (_, next_infer) = agent_term.act(&next_input, &valid, SelectionMode::Training);
880
881 let (action2, infer2) = agent_nonterm.act(&input, &valid, SelectionMode::Training);
883 let (_, next_infer2) = agent_nonterm.act(&next_input, &valid, SelectionMode::Training);
884
885 let loss_term = agent_term.learn_continuous(
887 &input,
888 &infer,
889 action,
890 &valid,
891 1.0,
892 &next_input,
893 &next_infer,
894 true,
895 );
896
897 let loss_nonterm = agent_nonterm.learn_continuous(
899 &input,
900 &infer2,
901 action2,
902 &valid,
903 1.0,
904 &next_input,
905 &next_infer2,
906 false,
907 );
908
909 assert!(
912 (loss_term - loss_nonterm).abs() > 1e-15,
913 "Terminal and non-terminal should produce different losses: {loss_term} vs {loss_nonterm}"
914 );
915 }
916
917 #[test]
918 fn test_learn_continuous_updates_actor() {
919 let mut agent: PcActorCritic = make_agent();
920 let input = vec![0.5; 9];
921 let next_input = vec![-0.5; 9];
922 let valid = vec![0, 1, 2];
923 let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
924 let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
925 let w_before = agent.actor.layers[0].weights.data.clone();
926 let _ = agent.learn_continuous(
927 &input,
928 &infer,
929 action,
930 &valid,
931 1.0,
932 &next_input,
933 &next_infer,
934 false,
935 );
936 assert_ne!(agent.actor.layers[0].weights.data, w_before);
937 }
938
939 #[test]
942 fn test_surprise_scale_below_low() {
943 let agent: PcActorCritic = make_agent();
944 let scale = agent.surprise_scale(0.01); assert!((scale - 0.1).abs() < 1e-12, "Expected 0.1, got {scale}");
946 }
947
948 #[test]
949 fn test_surprise_scale_above_high() {
950 let agent: PcActorCritic = make_agent();
951 let scale = agent.surprise_scale(0.20); assert!((scale - 2.0).abs() < 1e-12, "Expected 2.0, got {scale}");
953 }
954
955 #[test]
956 fn test_surprise_scale_midpoint_in_range() {
957 let agent: PcActorCritic = make_agent();
958 let midpoint = (0.02 + 0.15) / 2.0;
959 let scale = agent.surprise_scale(midpoint);
960 assert!(
961 scale > 0.1 && scale < 2.0,
962 "Midpoint scale {scale} out of range"
963 );
964 }
965
    #[test]
    // The surprise -> scale mapping must be non-decreasing across the whole
    // range (below low, interior, above high).
    fn test_surprise_scale_monotone_increasing() {
        let agent: PcActorCritic = make_agent();
        let s1 = agent.surprise_scale(0.01);
        let s2 = agent.surprise_scale(0.05);
        let s3 = agent.surprise_scale(0.10);
        let s4 = agent.surprise_scale(0.20);
        assert!(s1 <= s2, "s1={s1} > s2={s2}");
        assert!(s2 <= s3, "s2={s2} > s3={s3}");
        assert!(s3 <= s4, "s3={s3} > s4={s4}");
    }
977
    #[test]
    // With >= 10 buffered samples and adaptive_surprise on, thresholds come
    // from the buffer's mean/std instead of the static config values.
    fn test_adaptive_surprise_recalibrates_thresholds_after_many_episodes() {
        let mut config = default_config();
        config.adaptive_surprise = true;
        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();

        // Seed the buffer with 15 scores: 0.10, 0.12, ..., 0.38 (mean 0.24).
        for i in 0..15 {
            agent.push_surprise(0.1 + 0.02 * i as f64);
        }

        let scale_low = agent.surprise_scale(0.0);
        assert!(
            (scale_low - 0.1).abs() < 1e-12,
            "Expected 0.1 below adaptive low: got {scale_low}"
        );

        let scale_high = agent.surprise_scale(1.0);
        assert!(
            (scale_high - 2.0).abs() < 1e-12,
            "Expected 2.0 above adaptive high: got {scale_high}"
        );

        let scale_mid = agent.surprise_scale(0.24);
        assert!(
            scale_mid > 0.1 && scale_mid < 2.0,
            "Expected interpolated value at mean, got {scale_mid}"
        );
    }
1015
1016 #[test]
1017 fn test_entropy_regularization_prevents_policy_collapse() {
1018 let mut config = default_config();
1021 config.entropy_coeff = 0.1; let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
1023
1024 let input = vec![0.5; 9];
1025 let valid: Vec<usize> = (0..9).collect();
1026
1027 for _ in 0..20 {
1029 let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
1030 let trajectory = vec![TrajectoryStep {
1031 input: input.clone(),
1032 latent_concat: infer.latent_concat,
1033 y_conv: infer.y_conv,
1034 hidden_states: infer.hidden_states,
1035 prediction_errors: infer.prediction_errors,
1036 tanh_components: infer.tanh_components,
1037 action,
1038 valid_actions: valid.clone(),
1039 reward: 1.0,
1040 surprise_score: infer.surprise_score,
1041 steps_used: infer.steps_used,
1042 }];
1043 let _ = agent.learn(&trajectory);
1044 }
1045
1046 let mut seen = std::collections::HashSet::new();
1048 for _ in 0..50 {
1049 let (action, _) = agent.act(&input, &valid, SelectionMode::Training);
1050 seen.insert(action);
1051 }
1052 assert!(
1053 seen.len() > 1,
1054 "Entropy regularization should prevent collapse to single action, but only saw {:?}",
1055 seen
1056 );
1057 }
1058
1059 #[test]
1062 fn test_act_returns_valid_action() {
1063 let mut agent: PcActorCritic = make_agent();
1064 let input = vec![0.5; 9];
1065 let valid = vec![1, 3, 5, 7];
1066 for _ in 0..20 {
1067 let (action, _) = agent.act(&input, &valid, SelectionMode::Training);
1068 assert!(valid.contains(&action), "Action {action} not in valid set");
1069 }
1070 }
1071
1072 #[test]
1073 #[should_panic]
1074 fn test_act_empty_valid_panics() {
1075 let mut agent: PcActorCritic = make_agent();
1076 let input = vec![0.5; 9];
1077 let _ = agent.act(&input, &[], SelectionMode::Training);
1078 }
1079
1080 #[test]
1083 fn test_learn_improves_policy_for_rewarded_action() {
1084 let config = PcActorCriticConfig {
1086 actor: PcActorConfig {
1087 input_size: 9,
1088 hidden_layers: vec![LayerDef {
1089 size: 18,
1090 activation: Activation::Tanh,
1091 }],
1092 output_size: 9,
1093 output_activation: Activation::Linear,
1094 alpha: 0.1,
1095 tol: 0.01,
1096 min_steps: 1,
1097 max_steps: 5,
1098 lr_weights: 0.01,
1099 synchronous: true,
1100 temperature: 1.0,
1101 local_lambda: 1.0,
1102 residual: false,
1103 rezero_init: 0.001,
1104 },
1105 critic: MlpCriticConfig {
1106 input_size: 27,
1107 hidden_layers: vec![LayerDef {
1108 size: 36,
1109 activation: Activation::Tanh,
1110 }],
1111 output_activation: Activation::Linear,
1112 lr: 0.005,
1113 },
1114 gamma: 0.99,
1115 surprise_low: 0.02,
1116 surprise_high: 0.15,
1117 adaptive_surprise: false,
1118 surprise_buffer_size: 100,
1119 entropy_coeff: 0.0, };
1121 let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
1122
1123 let input = vec![0.0; 9];
1124 let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
1125 let target_action = 4; for _ in 0..200 {
1129 let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1130 let trajectory = vec![TrajectoryStep {
1131 input: input.clone(),
1132 latent_concat: infer.latent_concat,
1133 y_conv: infer.y_conv,
1134 hidden_states: infer.hidden_states,
1135 prediction_errors: infer.prediction_errors,
1136 tanh_components: infer.tanh_components,
1137 action: target_action,
1138 valid_actions: valid.clone(),
1139 reward: 1.0,
1140 surprise_score: infer.surprise_score,
1141 steps_used: infer.steps_used,
1142 }];
1143 agent.learn(&trajectory);
1144 }
1145
1146 let (action, infer) = agent.act(&input, &valid, SelectionMode::Play);
1149
1150 let logit_4 = infer.y_conv[4];
1152 let max_other = valid
1153 .iter()
1154 .filter(|&&a| a != 4)
1155 .map(|&a| infer.y_conv[a])
1156 .fold(f64::NEG_INFINITY, f64::max);
1157
1158 eprintln!(
1159 "DIAGNOSTIC: action={action}, logit[4]={logit_4:.4}, max_other={max_other:.4}, \
1160 y_conv={:?}",
1161 infer
1162 .y_conv
1163 .iter()
1164 .map(|v| format!("{v:.3}"))
1165 .collect::<Vec<_>>()
1166 );
1167
1168 assert_eq!(
1169 action, target_action,
1170 "After 200 episodes rewarding action 4, agent should prefer it. Got action {action}"
1171 );
1172 }
1173
1174 #[test]
1177 fn test_new_returns_error_zero_temperature() {
1178 let mut config = default_config();
1179 config.actor.temperature = 0.0;
1180 let err = PcActorCritic::new(config, 42)
1181 .map(|_: PcActorCritic| ())
1182 .unwrap_err();
1183 assert!(format!("{err}").contains("temperature"));
1184 }
1185
1186 #[test]
1187 fn test_new_returns_error_zero_input_size() {
1188 let mut config = default_config();
1189 config.actor.input_size = 0;
1190 config.critic.input_size = 0;
1191 assert!(PcActorCritic::new(config, 42)
1192 .map(|_: PcActorCritic| ())
1193 .is_err());
1194 }
1195
1196 #[test]
1197 fn test_new_returns_error_zero_output_size() {
1198 let mut config = default_config();
1199 config.actor.output_size = 0;
1200 assert!(PcActorCritic::new(config, 42)
1201 .map(|_: PcActorCritic| ())
1202 .is_err());
1203 }
1204
1205 #[test]
1206 fn test_new_returns_error_negative_gamma() {
1207 let mut config = default_config();
1208 config.gamma = -0.1;
1209 let err = PcActorCritic::new(config, 42)
1210 .map(|_: PcActorCritic| ())
1211 .unwrap_err();
1212 assert!(format!("{err}").contains("gamma"));
1213 }
1214
1215 #[test]
1218 fn test_activation_cache_new_creates_empty() {
1219 let cache: ActivationCache = ActivationCache::new(3);
1220 assert_eq!(cache.batch_size(), 0);
1221 }
1222
1223 #[test]
1224 fn test_activation_cache_record_increments_batch_size() {
1225 let mut agent: PcActorCritic = make_agent();
1226 let input = vec![0.5; 9];
1227 let valid = vec![0, 1, 2];
1228 let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1229
1230 let num_hidden = infer.hidden_states.len();
1231 let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1232 cache.record(&infer.hidden_states);
1233 assert_eq!(cache.batch_size(), 1);
1234 }
1235
1236 #[test]
1237 fn test_activation_cache_record_multiple() {
1238 let mut agent: PcActorCritic = make_agent();
1239 let valid = vec![0, 1, 2];
1240 let init_input = vec![0.5; 9];
1241 let num_hidden = {
1242 let (_, infer) = agent.act(&init_input, &valid, SelectionMode::Training);
1243 infer.hidden_states.len()
1244 };
1245
1246 let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1247 for i in 0..5 {
1248 let input = vec![i as f64 * 0.1; 9];
1249 let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1250 cache.record(&infer.hidden_states);
1251 }
1252 assert_eq!(cache.batch_size(), 5);
1253 }
1254
1255 #[test]
1256 fn test_activation_cache_recorded_values_match_hidden_states() {
1257 let mut agent: PcActorCritic = make_agent();
1258 let input = vec![0.5; 9];
1259 let valid = vec![0, 1, 2];
1260 let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1261
1262 let num_hidden = infer.hidden_states.len();
1263 let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1264 cache.record(&infer.hidden_states);
1265
1266 for (layer_idx, expected) in infer.hidden_states.iter().enumerate() {
1268 let layer_data = cache.layer(layer_idx);
1269 assert_eq!(layer_data.len(), 1);
1270 assert_eq!(layer_data[0], *expected);
1271 }
1272 }
1273
1274 #[test]
1277 fn test_activation_cache_layer_count() {
1278 let mut agent: PcActorCritic = make_agent();
1279 let input = vec![0.5; 9];
1280 let valid = vec![0, 1, 2];
1281 let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1282
1283 let num_hidden = infer.hidden_states.len();
1284 let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1285 cache.record(&infer.hidden_states);
1286
1287 assert_eq!(cache.num_layers(), num_hidden);
1288 }
1289
1290 #[test]
1291 fn test_activation_cache_layer_sample_count() {
1292 let mut agent: PcActorCritic = make_agent();
1293 let valid = vec![0, 1, 2];
1294 let init_input = vec![0.5; 9];
1295 let num_hidden = {
1296 let (_, infer) = agent.act(&init_input, &valid, SelectionMode::Training);
1297 infer.hidden_states.len()
1298 };
1299
1300 let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1301 for i in 0..10 {
1302 let input = vec![i as f64 * 0.1; 9];
1303 let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1304 cache.record(&infer.hidden_states);
1305 }
1306
1307 for layer_idx in 0..num_hidden {
1308 assert_eq!(
1309 cache.layer(layer_idx).len(),
1310 10,
1311 "Layer {layer_idx} should have 10 samples"
1312 );
1313 }
1314 }
1315
1316 fn build_caches_for_agent(
1319 agent: &mut PcActorCritic,
1320 batch_size: usize,
1321 ) -> (ActivationCache, ActivationCache) {
1322 let num_actor_hidden = agent.config.actor.hidden_layers.len();
1323 let num_critic_hidden = agent.config.critic.hidden_layers.len();
1324 let mut actor_cache: ActivationCache = ActivationCache::new(num_actor_hidden);
1325 let mut critic_cache: ActivationCache = ActivationCache::new(num_critic_hidden);
1326 let valid: Vec<usize> = (0..agent.config.actor.output_size).collect();
1327 for i in 0..batch_size {
1328 let input: Vec<f64> = (0..agent.config.actor.input_size)
1329 .map(|j| ((i * 9 + j) as f64 * 0.1).sin())
1330 .collect();
1331 let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1332 actor_cache.record(&infer.hidden_states);
1333 let mut critic_input = input;
1334 critic_input.extend_from_slice(&infer.latent_concat);
1335 let (_value, critic_hidden) = agent.critic.forward_with_hidden(&critic_input);
1336 critic_cache.record(&critic_hidden);
1337 }
1338 (actor_cache, critic_cache)
1339 }
1340
1341 #[test]
1342 fn test_agent_crossover_produces_valid_agent() {
1343 let config = default_config();
1344 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1345 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1346
1347 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1348 let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1349
1350 let child: PcActorCritic = PcActorCritic::crossover(
1351 &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1352 )
1353 .unwrap();
1354
1355 assert_eq!(
1356 child.config.actor.hidden_layers.len(),
1357 agent_a.config.actor.hidden_layers.len()
1358 );
1359 }
1360
1361 #[test]
1362 fn test_agent_crossover_actor_weights_differ() {
1363 let config = default_config();
1364 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1365 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1366
1367 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1368 let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1369
1370 let child: PcActorCritic = PcActorCritic::crossover(
1371 &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1372 )
1373 .unwrap();
1374
1375 assert_ne!(
1376 child.actor.layers[0].weights.data,
1377 agent_a.actor.layers[0].weights.data
1378 );
1379 assert_ne!(
1380 child.actor.layers[0].weights.data,
1381 agent_b.actor.layers[0].weights.data
1382 );
1383 }
1384
1385 #[test]
1386 fn test_agent_crossover_critic_weights_differ() {
1387 let config = default_config();
1388 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1389 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1390
1391 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1392 let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1393
1394 let child: PcActorCritic = PcActorCritic::crossover(
1395 &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1396 )
1397 .unwrap();
1398
1399 assert_ne!(
1400 child.critic.layers[0].weights.data,
1401 agent_a.critic.layers[0].weights.data
1402 );
1403 assert_ne!(
1404 child.critic.layers[0].weights.data,
1405 agent_b.critic.layers[0].weights.data
1406 );
1407 }
1408
1409 #[test]
1412 fn test_agent_crossover_child_can_infer() {
1413 let config = default_config();
1414 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1415 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1416
1417 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1418 let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1419
1420 let mut child: PcActorCritic = PcActorCritic::crossover(
1421 &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1422 )
1423 .unwrap();
1424
1425 let input = vec![0.5; 9];
1426 let valid = vec![0, 1, 2, 3, 4];
1427 let (action, _) = child.act(&input, &valid, SelectionMode::Training);
1428 assert!(valid.contains(&action), "Action {action} not in valid set");
1429 }
1430
1431 #[test]
1432 fn test_agent_crossover_child_can_learn() {
1433 let config = default_config();
1434 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1435 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1436
1437 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1438 let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1439
1440 let mut child: PcActorCritic = PcActorCritic::crossover(
1441 &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1442 )
1443 .unwrap();
1444
1445 let trajectory = make_trajectory(&mut child);
1446 let loss = child.learn(&trajectory);
1447 assert!(loss.is_finite(), "Child learn loss not finite: {loss}");
1448 }
1449
1450 #[test]
1451 fn test_agent_crossover_mismatched_batch_size_error() {
1452 let config = default_config();
1453 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1454 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1455
1456 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1457 let (ac_b, _cc_b) = build_caches_for_agent(&mut agent_b, 30); let (_, cc_b_match) = build_caches_for_agent(&mut agent_b, 50);
1459
1460 let result = PcActorCritic::crossover(
1462 &agent_a,
1463 &agent_b,
1464 &ac_a,
1465 &ac_b,
1466 &cc_a,
1467 &cc_b_match,
1468 0.5,
1469 config,
1470 99,
1471 );
1472 assert!(result.is_err(), "Mismatched actor batch sizes should error");
1473 }
1474
1475 #[test]
1478 fn test_agent_crossover_with_separate_critic_caches() {
1479 let config = default_config();
1480 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1481 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1482
1483 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1484 let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1485
1486 let child: PcActorCritic = PcActorCritic::crossover(
1487 &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1488 )
1489 .unwrap();
1490
1491 assert_eq!(child.critic.layers.len(), agent_a.critic.layers.len());
1492 }
1493
1494 #[test]
1495 fn test_agent_crossover_critic_uses_own_caches() {
1496 let config = default_config();
1497 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1498 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1499
1500 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1501 let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1502
1503 let child: PcActorCritic = PcActorCritic::crossover(
1504 &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1505 )
1506 .unwrap();
1507
1508 assert_ne!(
1509 child.critic.layers[0].weights.data,
1510 agent_a.critic.layers[0].weights.data
1511 );
1512 assert_ne!(
1513 child.critic.layers[0].weights.data,
1514 agent_b.critic.layers[0].weights.data
1515 );
1516 }
1517
1518 #[test]
1519 fn test_agent_crossover_mismatched_critic_batch_error() {
1520 let config = default_config();
1521 let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1522 let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1523
1524 let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1525 let (ac_b, _) = build_caches_for_agent(&mut agent_b, 50);
1526 let (_, cc_b_small) = build_caches_for_agent(&mut agent_b, 30);
1528
1529 let result = PcActorCritic::crossover(
1530 &agent_a,
1531 &agent_b,
1532 &ac_a,
1533 &ac_b,
1534 &cc_a,
1535 &cc_b_small,
1536 0.5,
1537 config,
1538 99,
1539 );
1540 assert!(
1541 result.is_err(),
1542 "Mismatched critic batch sizes should error"
1543 );
1544 }
1545
1546 #[test]
1549 fn test_activation_cache_accessible_from_crate() {
1550 let _cache: crate::pc_actor_critic::ActivationCache = ActivationCache::new(1);
1552 }
1553
    /// `cca_neuron_alignment` must be reachable from this crate path and
    /// accept two matrices of equal shape (here the same 10x3 zero matrix).
    #[test]
    fn test_cca_neuron_alignment_accessible_from_crate() {
        use crate::linalg::cpu::CpuLinAlg;
        // The `LinAlg` trait must be in scope for `zeros_mat` to resolve.
        use crate::linalg::LinAlg;
        let mat = CpuLinAlg::zeros_mat(10, 3);
        let _perm = crate::matrix::cca_neuron_alignment::<CpuLinAlg>(&mat, &mat).unwrap();
    }
1562}