Skip to main content

pc_rl_core/
pc_actor_critic.rs

1// Author: Julian Bolivar
2// Version: 1.0.0
3// Date: 2026-03-25
4
5//! Integrated PC Actor-Critic agent.
6//!
7//! Combines [`PcActor`] for action selection via predictive coding inference
8//! with [`MlpCritic`] for value estimation. Supports REINFORCE episodic
9//! learning, TD(0) continuous learning, surprise-based scheduling, and
10//! entropy regularization.
11//!
12//! Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
13
14use std::collections::VecDeque;
15
16use rand::rngs::StdRng;
17use serde::{Deserialize, Serialize};
18
19use crate::error::PcError;
20use crate::linalg::cpu::CpuLinAlg;
21use crate::linalg::LinAlg;
22use crate::mlp_critic::{MlpCritic, MlpCriticConfig};
23use crate::pc_actor::{InferResult, PcActor, PcActorConfig, SelectionMode};
24
/// Serde fallback for [`PcActorCriticConfig::surprise_buffer_size`].
///
/// Used when a serialized configuration omits the field; yields 100 entries.
fn default_surprise_buffer_size() -> usize {
    const DEFAULT_SURPRISE_BUFFER_SIZE: usize = 100;
    DEFAULT_SURPRISE_BUFFER_SIZE
}
29
/// Configuration for the integrated PC Actor-Critic agent.
///
/// Bundles the actor and critic network configurations with the
/// agent-level learning hyperparameters (discounting, surprise-based
/// update scaling, and entropy regularization).
///
/// During learning the critic is fed the state vector concatenated with the
/// actor's `latent_concat`, so `critic.input_size` has to match that combined
/// length (in the example below: 9 inputs + 18 hidden units = 27).
///
/// # Examples
///
/// ```
/// use pc_rl_core::activation::Activation;
/// use pc_rl_core::layer::LayerDef;
/// use pc_rl_core::mlp_critic::MlpCriticConfig;
/// use pc_rl_core::pc_actor::PcActorConfig;
/// use pc_rl_core::pc_actor_critic::PcActorCriticConfig;
///
/// let config = PcActorCriticConfig {
///     actor: PcActorConfig {
///         input_size: 9,
///         hidden_layers: vec![LayerDef { size: 18, activation: Activation::Tanh }],
///         output_size: 9,
///         output_activation: Activation::Tanh,
///         alpha: 0.1, tol: 0.01, min_steps: 1, max_steps: 20,
///         lr_weights: 0.01, synchronous: true, temperature: 1.0,
///         local_lambda: 1.0,
///         residual: false,
///         rezero_init: 0.001,
///     },
///     critic: MlpCriticConfig {
///         input_size: 27,
///         hidden_layers: vec![LayerDef { size: 36, activation: Activation::Tanh }],
///         output_activation: Activation::Linear,
///         lr: 0.005,
///     },
///     gamma: 0.95,
///     surprise_low: 0.02,
///     surprise_high: 0.15,
///     adaptive_surprise: false,
///     surprise_buffer_size: 100,
///     entropy_coeff: 0.01,
/// };
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PcActorCriticConfig {
    /// Actor (PC network) configuration.
    pub actor: PcActorConfig,
    /// Critic (MLP value function) configuration.
    pub critic: MlpCriticConfig,
    /// Discount factor for computing returns.
    ///
    /// `PcActorCritic::new` rejects values outside `[0.0, 1.0]`.
    pub gamma: f64,
    /// Static lower surprise threshold: at or below it,
    /// `PcActorCritic::surprise_scale` returns its minimum factor (0.1).
    pub surprise_low: f64,
    /// Static upper surprise threshold: at or above it,
    /// `PcActorCritic::surprise_scale` returns its maximum factor (2.0).
    pub surprise_high: f64,
    /// Whether to adaptively recalibrate surprise thresholds.
    ///
    /// When true, surprise scores seen during learning are buffered, and once
    /// at least 10 are available the thresholds are derived from the buffer's
    /// mean and standard deviation instead of `surprise_low`/`surprise_high`.
    pub adaptive_surprise: bool,
    /// Maximum number of surprise scores in the adaptive buffer.
    /// Only used when `adaptive_surprise` is true. Default: 100
    /// (applied by serde when the field is missing from serialized data).
    #[serde(default = "default_surprise_buffer_size")]
    pub surprise_buffer_size: usize,
    /// Entropy regularization coefficient.
    ///
    /// Multiplies the `ln(pi) + 1` term that is subtracted from the policy
    /// gradient signal for every valid action.
    pub entropy_coeff: f64,
}
88
/// A single step in a trajectory collected during an episode.
///
/// Stores everything `PcActorCritic::learn` needs to replay the step without
/// re-running inference: the state, the action and reward, and the pieces of
/// the `InferResult` produced when the action was chosen. `learn` rebuilds an
/// `InferResult` from these fields (with `converged` set to `false`, since the
/// flag is not stored here).
///
/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
#[derive(Debug, Clone)]
pub struct TrajectoryStep<L: LinAlg = CpuLinAlg> {
    /// Board state input vector.
    pub input: L::Vector,
    /// Concatenated hidden layer activations from inference.
    /// Appended to `input` to form the critic's input during learning.
    pub latent_concat: L::Vector,
    /// Converged output logits from inference.
    pub y_conv: L::Vector,
    /// Per-layer hidden state activations from inference (for backprop).
    pub hidden_states: Vec<L::Vector>,
    /// Per-layer prediction errors from the PC inference loop.
    pub prediction_errors: Vec<L::Vector>,
    /// Per-layer tanh components for residual layers (for correct backward pass).
    pub tanh_components: Vec<Option<L::Vector>>,
    /// Action taken at this step (used as an index into the policy vector).
    pub action: usize,
    /// Valid actions at this step (needed for masked softmax).
    pub valid_actions: Vec<usize>,
    /// Reward received after taking the action.
    pub reward: f64,
    /// Prediction error from inference; drives the surprise-based update scale.
    pub surprise_score: f64,
    /// Number of inference steps used.
    pub steps_used: usize,
}
117
118/// Cache for hidden layer activations captured during inference.
119///
120/// Used by CCA neuron alignment during crossover to compare functional
121/// neuron responses between parent networks. Activations are recorded
122/// during normal fitness evaluation at zero additional compute cost.
123///
124/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
125///
126/// # Examples
127///
128/// ```
129/// use pc_rl_core::pc_actor_critic::ActivationCache;
130/// use pc_rl_core::linalg::cpu::CpuLinAlg;
131///
132/// let cache: ActivationCache<CpuLinAlg> = ActivationCache::new(2);
133/// assert_eq!(cache.batch_size(), 0);
134/// assert_eq!(cache.num_layers(), 2);
135/// ```
136#[derive(Debug, Clone)]
137pub struct ActivationCache<L: LinAlg = CpuLinAlg> {
138    /// activations[layer_idx][batch_sample_idx] = neuron activations.
139    layers: Vec<Vec<L::Vector>>,
140}
141
142impl<L: LinAlg> ActivationCache<L> {
143    /// Creates a new empty cache with the given number of hidden layers.
144    ///
145    /// # Arguments
146    ///
147    /// * `num_layers` - Number of hidden layers to track.
148    pub fn new(num_layers: usize) -> Self {
149        Self {
150            layers: (0..num_layers).map(|_| Vec::new()).collect(),
151        }
152    }
153
154    /// Returns the number of samples recorded so far.
155    pub fn batch_size(&self) -> usize {
156        self.layers.first().map_or(0, |l| l.len())
157    }
158
159    /// Returns the number of hidden layers in the cache.
160    pub fn num_layers(&self) -> usize {
161        self.layers.len()
162    }
163
164    /// Records hidden layer activations from a single inference.
165    ///
166    /// # Arguments
167    ///
168    /// * `hidden_states` - Per-layer activations from `InferResult::hidden_states`.
169    pub fn record(&mut self, hidden_states: &[L::Vector]) {
170        for (layer, state) in self.layers.iter_mut().zip(hidden_states.iter()) {
171            layer.push(state.clone());
172        }
173    }
174
175    /// Returns the recorded activations for a given layer.
176    ///
177    /// # Arguments
178    ///
179    /// * `layer_idx` - Index of the hidden layer.
180    pub fn layer(&self, layer_idx: usize) -> &[L::Vector] {
181        &self.layers[layer_idx]
182    }
183}
184
/// Integrated PC Actor-Critic agent.
///
/// Combines a predictive coding actor with an MLP critic for
/// reinforcement learning with surprise-based scheduling.
///
/// The agent owns its RNG, so action sampling is reproducible for a given
/// seed, and keeps a bounded history of recent surprise scores that is only
/// filled when `config.adaptive_surprise` is enabled.
///
/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
#[derive(Debug)]
pub struct PcActorCritic<L: LinAlg = CpuLinAlg> {
    /// The PC actor network.
    pub(crate) actor: PcActor<L>,
    /// The MLP critic (value function).
    pub(crate) critic: MlpCritic<L>,
    /// Agent configuration.
    ///
    /// Public and mutable: values are re-read on every learning call, and
    /// only `new` validates `gamma`.
    pub config: PcActorCriticConfig,
    /// Random number generator for action selection.
    /// Also used to initialize the actor and critic when built via `new`.
    rng: StdRng,
    /// Circular buffer of recent surprise scores for adaptive thresholds.
    /// Oldest entries sit at the front and are evicted first.
    surprise_buffer: VecDeque<f64>,
}
204
205impl<L: LinAlg> PcActorCritic<L> {
206    /// Creates a new PC Actor-Critic agent.
207    ///
208    /// # Arguments
209    ///
210    /// * `config` - Agent configuration with actor, critic, and learning parameters.
211    /// * `seed` - Random seed for reproducibility.
212    /// # Errors
213    ///
214    /// Returns `PcError::ConfigValidation` if gamma is out of `[0.0, 1.0]`,
215    /// or if actor/critic config is invalid.
216    pub fn new(config: PcActorCriticConfig, seed: u64) -> Result<Self, PcError> {
217        if !(0.0..=1.0).contains(&config.gamma) {
218            return Err(PcError::ConfigValidation(format!(
219                "gamma must be in [0.0, 1.0], got {}",
220                config.gamma
221            )));
222        }
223
224        use rand::SeedableRng;
225        let mut rng = StdRng::seed_from_u64(seed);
226        let actor = PcActor::<L>::new(config.actor.clone(), &mut rng)?;
227        let critic = MlpCritic::<L>::new(config.critic.clone(), &mut rng)?;
228        Ok(Self {
229            actor,
230            critic,
231            config,
232            rng,
233            surprise_buffer: VecDeque::new(),
234        })
235    }
236
237    /// Creates a child agent by crossing over two parent agents using CCA neuron alignment.
238    ///
239    /// Delegates to `PcActor::crossover` and `MlpCritic::crossover`, converting
240    /// activation caches to the matrix format expected by CCA alignment.
241    ///
242    /// # Arguments
243    ///
244    /// * `parent_a` - First parent agent (reference, typically higher fitness).
245    /// * `parent_b` - Second parent agent.
246    /// * `cache_a` - Activation cache for parent A on the reference batch.
247    /// * `cache_b` - Activation cache for parent B on the reference batch.
248    /// * `alpha` - Blending weight: 1.0 = all A, 0.0 = all B.
249    /// * `child_config` - Configuration for the child agent.
250    /// * `seed` - Random seed for the child's RNG.
251    ///
252    /// # Errors
253    ///
254    /// Returns `PcError::DimensionMismatch` if activation caches have different
255    /// batch sizes. Returns `PcError::ConfigValidation` if child config is invalid.
256    #[allow(clippy::too_many_arguments)]
257    pub fn crossover(
258        parent_a: &PcActorCritic<L>,
259        parent_b: &PcActorCritic<L>,
260        actor_cache_a: &ActivationCache<L>,
261        actor_cache_b: &ActivationCache<L>,
262        critic_cache_a: &ActivationCache<L>,
263        critic_cache_b: &ActivationCache<L>,
264        alpha: f64,
265        child_config: PcActorCriticConfig,
266        seed: u64,
267    ) -> Result<Self, PcError> {
268        // Validate actor batch sizes match
269        if actor_cache_a.batch_size() != actor_cache_b.batch_size() {
270            return Err(PcError::DimensionMismatch {
271                expected: actor_cache_a.batch_size(),
272                got: actor_cache_b.batch_size(),
273                context: "actor activation cache batch sizes must match for crossover",
274            });
275        }
276        // Validate critic batch sizes match
277        if critic_cache_a.batch_size() != critic_cache_b.batch_size() {
278            return Err(PcError::DimensionMismatch {
279                expected: critic_cache_a.batch_size(),
280                got: critic_cache_b.batch_size(),
281                context: "critic activation cache batch sizes must match for crossover",
282            });
283        }
284
285        // Convert caches to matrices [batch × neurons] for CCA
286        let actor_cache_mats_a = cache_to_matrices::<L>(actor_cache_a);
287        let actor_cache_mats_b = cache_to_matrices::<L>(actor_cache_b);
288        let critic_cache_mats_a = cache_to_matrices::<L>(critic_cache_a);
289        let critic_cache_mats_b = cache_to_matrices::<L>(critic_cache_b);
290
291        use rand::SeedableRng;
292        let mut rng = StdRng::seed_from_u64(seed);
293
294        // Crossover actor with actor-specific caches
295        let actor = PcActor::<L>::crossover(
296            &parent_a.actor,
297            &parent_b.actor,
298            &actor_cache_mats_a,
299            &actor_cache_mats_b,
300            alpha,
301            child_config.actor.clone(),
302            &mut rng,
303        )?;
304
305        // Crossover critic with critic-specific caches
306        let critic = MlpCritic::<L>::crossover(
307            &parent_a.critic,
308            &parent_b.critic,
309            &critic_cache_mats_a,
310            &critic_cache_mats_b,
311            alpha,
312            child_config.critic.clone(),
313            &mut rng,
314        )?;
315
316        Ok(Self {
317            actor,
318            critic,
319            config: child_config,
320            rng,
321            surprise_buffer: VecDeque::new(),
322        })
323    }
324
325    /// Reconstructs an agent from pre-built components (used by serializer).
326    ///
327    /// # Arguments
328    ///
329    /// * `config` - Agent configuration.
330    /// * `actor` - Pre-built PC actor with loaded weights.
331    /// * `critic` - Pre-built MLP critic with loaded weights.
332    /// * `rng` - Random number generator.
333    pub fn from_parts(
334        config: PcActorCriticConfig,
335        actor: PcActor<L>,
336        critic: MlpCritic<L>,
337        rng: StdRng,
338    ) -> Self {
339        Self {
340            actor,
341            critic,
342            config,
343            rng,
344            surprise_buffer: VecDeque::new(),
345        }
346    }
347
    /// Runs PC inference without selecting an action or modifying RNG state.
    ///
    /// Use this when you only need the inference result (e.g., for TD(0)
    /// next-state evaluation) without side effects. The method takes `&self`
    /// and simply forwards to the actor's `infer`.
    ///
    /// # Arguments
    ///
    /// * `input` - Board state vector.
    ///
    /// # Returns
    ///
    /// The actor's `InferResult` for `input`.
    ///
    /// # Panics
    ///
    /// Panics if `input.len() != config.actor.input_size`.
    /// NOTE(review): the length check is not visible in this file; presumably
    /// it is enforced inside `PcActor::infer` — confirm there.
    pub fn infer(&self, input: &[f64]) -> InferResult<L> {
        self.actor.infer(input)
    }
363
364    /// Selects an action given the current state.
365    ///
366    /// Runs PC inference on the input, then selects an action using the
367    /// converged logits and the specified selection mode.
368    ///
369    /// # Arguments
370    ///
371    /// * `input` - Board state vector.
372    /// * `valid_actions` - Indices of legal actions.
373    /// * `mode` - Training (stochastic) or Play (deterministic).
374    ///
375    /// # Panics
376    ///
377    /// Panics if `valid_actions` is empty.
378    pub fn act(
379        &mut self,
380        input: &[f64],
381        valid_actions: &[usize],
382        mode: SelectionMode,
383    ) -> (usize, InferResult<L>) {
384        let infer_result = self.actor.infer(input);
385        let action =
386            self.actor
387                .select_action(&infer_result.y_conv, valid_actions, mode, &mut self.rng);
388        (action, infer_result)
389    }
390
391    /// Learns from a complete episode trajectory using REINFORCE with baseline.
392    ///
393    /// Empty trajectory returns 0.0 without modifying weights. Otherwise computes
394    /// discounted returns, advantages, and updates both actor and critic.
395    ///
396    /// # Arguments
397    ///
398    /// * `trajectory` - Sequence of steps from an episode.
399    ///
400    /// # Returns
401    ///
402    /// Average critic loss over the trajectory.
403    pub fn learn(&mut self, trajectory: &[TrajectoryStep<L>]) -> f64 {
404        if trajectory.is_empty() {
405            return 0.0;
406        }
407
408        let n = trajectory.len();
409
410        // Compute discounted returns backward
411        let mut returns = vec![0.0; n];
412        returns[n - 1] = trajectory[n - 1].reward;
413        for t in (0..n - 1).rev() {
414            returns[t] = trajectory[t].reward + self.config.gamma * returns[t + 1];
415        }
416
417        let mut total_loss = 0.0;
418
419        for (t, step) in trajectory.iter().enumerate() {
420            // Build critic input: concat(input, latent_concat)
421            let input_vec = L::vec_to_vec(&step.input);
422            let latent_vec = L::vec_to_vec(&step.latent_concat);
423            let mut critic_input = input_vec.clone();
424            critic_input.extend_from_slice(&latent_vec);
425
426            // V(s)
427            let value = self.critic.forward(&critic_input);
428            let advantage = returns[t] - value;
429
430            // Update critic toward discounted return
431            let loss = self.critic.update(&critic_input, returns[t]);
432            total_loss += loss;
433
434            // Policy gradient
435            let y_conv_vec = L::vec_to_vec(&step.y_conv);
436            let scaled: Vec<f64> = y_conv_vec
437                .iter()
438                .map(|&v| v / self.actor.config.temperature)
439                .collect();
440            let scaled_l = L::vec_from_slice(&scaled);
441            let pi_l = L::softmax_masked(&scaled_l, &step.valid_actions);
442            let pi = L::vec_to_vec(&pi_l);
443
444            let mut delta = vec![0.0; pi.len()];
445            for &i in &step.valid_actions {
446                delta[i] = pi[i];
447            }
448            delta[step.action] -= 1.0;
449
450            // Scale by advantage
451            for &i in &step.valid_actions {
452                delta[i] *= advantage;
453            }
454
455            // Entropy regularization
456            for &i in &step.valid_actions {
457                let log_pi = (pi[i].max(1e-10)).ln();
458                delta[i] -= self.config.entropy_coeff * (log_pi + 1.0);
459            }
460
461            // Compute surprise scale and update actor using stored hidden_states
462            let s_scale = self.surprise_scale(step.surprise_score);
463
464            let stored_infer = InferResult {
465                y_conv: step.y_conv.clone(),
466                latent_concat: step.latent_concat.clone(),
467                hidden_states: step.hidden_states.clone(),
468                prediction_errors: step.prediction_errors.clone(),
469                surprise_score: step.surprise_score,
470                steps_used: step.steps_used,
471                converged: false,
472                tanh_components: step.tanh_components.clone(),
473            };
474            self.actor
475                .update_weights(&delta, &stored_infer, &input_vec, s_scale);
476
477            // Push surprise to adaptive buffer
478            if self.config.adaptive_surprise {
479                self.push_surprise(step.surprise_score);
480            }
481        }
482
483        total_loss / n as f64
484    }
485
486    /// Single-step TD(0) continuous learning.
487    ///
488    /// # Arguments
489    ///
490    /// * `input` - Current state.
491    /// * `infer` - Inference result from `act` at current state.
492    /// * `action` - Action taken.
493    /// * `valid_actions` - Valid actions at current state.
494    /// * `reward` - Reward received.
495    /// * `next_input` - Next state.
496    /// * `next_infer` - Inference result from `act` at next state.
497    /// * `terminal` - Whether the episode ended.
498    ///
499    /// # Returns
500    ///
501    /// Critic loss for this step.
502    #[allow(clippy::too_many_arguments)]
503    pub fn learn_continuous(
504        &mut self,
505        input: &[f64],
506        infer: &InferResult<L>,
507        action: usize,
508        valid_actions: &[usize],
509        reward: f64,
510        next_input: &[f64],
511        next_infer: &InferResult<L>,
512        terminal: bool,
513    ) -> f64 {
514        // Build critic inputs
515        let latent_vec = L::vec_to_vec(&infer.latent_concat);
516        let mut critic_input = input.to_vec();
517        critic_input.extend_from_slice(&latent_vec);
518
519        let next_latent_vec = L::vec_to_vec(&next_infer.latent_concat);
520        let mut next_critic_input = next_input.to_vec();
521        next_critic_input.extend_from_slice(&next_latent_vec);
522
523        let v_s = self.critic.forward(&critic_input);
524        let v_next = if terminal {
525            0.0
526        } else {
527            self.critic.forward(&next_critic_input)
528        };
529
530        let target = reward
531            + if terminal {
532                0.0
533            } else {
534                self.config.gamma * v_next
535            };
536        let td_error = target - v_s;
537
538        // Update critic
539        let loss = self.critic.update(&critic_input, target);
540
541        // Policy gradient (same formula as learn, but scaled by td_error)
542        let y_conv_vec = L::vec_to_vec(&infer.y_conv);
543        let scaled: Vec<f64> = y_conv_vec
544            .iter()
545            .map(|&v| v / self.actor.config.temperature)
546            .collect();
547        let scaled_l = L::vec_from_slice(&scaled);
548        let pi_l = L::softmax_masked(&scaled_l, valid_actions);
549        let pi = L::vec_to_vec(&pi_l);
550
551        let mut delta = vec![0.0; pi.len()];
552        for &i in valid_actions {
553            delta[i] = pi[i];
554        }
555        delta[action] -= 1.0;
556
557        for &i in valid_actions {
558            delta[i] *= td_error;
559        }
560
561        // Entropy regularization
562        for &i in valid_actions {
563            let log_pi = (pi[i].max(1e-10)).ln();
564            delta[i] -= self.config.entropy_coeff * (log_pi + 1.0);
565        }
566
567        let s_scale = self.surprise_scale(infer.surprise_score);
568        self.actor.update_weights(&delta, infer, input, s_scale);
569
570        if self.config.adaptive_surprise {
571            self.push_surprise(infer.surprise_score);
572        }
573
574        loss
575    }
576
577    /// Computes the learning rate scale factor based on surprise score.
578    ///
579    /// - surprise <= low → 0.1
580    /// - surprise >= high → 2.0
581    /// - Between → linear interpolation from 0.1 to 2.0
582    ///
583    /// If adaptive surprise is enabled and the buffer has >= 10 entries,
584    /// thresholds are dynamically recomputed from the buffer statistics.
585    pub fn surprise_scale(&self, surprise: f64) -> f64 {
586        let (low, high) = if self.config.adaptive_surprise && self.surprise_buffer.len() >= 10 {
587            let mean = self.surprise_buffer.iter().sum::<f64>() / self.surprise_buffer.len() as f64;
588            let variance = self
589                .surprise_buffer
590                .iter()
591                .map(|&s| (s - mean) * (s - mean))
592                .sum::<f64>()
593                / self.surprise_buffer.len() as f64;
594            let std = variance.sqrt();
595            let lo = (mean - 0.5 * std).max(0.0);
596            let hi = mean + 1.5 * std;
597            (lo, hi)
598        } else {
599            (self.config.surprise_low, self.config.surprise_high)
600        };
601
602        if surprise <= low {
603            0.1
604        } else if surprise >= high {
605            2.0
606        } else {
607            // Linear interpolation
608            let t = (surprise - low) / (high - low);
609            0.1 + t * (2.0 - 0.1)
610        }
611    }
612
613    /// Pushes a surprise score into the adaptive buffer (circular).
614    fn push_surprise(&mut self, surprise: f64) {
615        if self.surprise_buffer.len() >= self.config.surprise_buffer_size {
616            self.surprise_buffer.pop_front();
617        }
618        self.surprise_buffer.push_back(surprise);
619    }
620}
621
622/// Converts an `ActivationCache` into a vector of matrices `[batch × neurons]`,
623/// one per hidden layer, suitable for CCA alignment.
624fn cache_to_matrices<L: LinAlg>(cache: &ActivationCache<L>) -> Vec<L::Matrix> {
625    let num_layers = cache.num_layers();
626    let batch_size = cache.batch_size();
627    let mut matrices = Vec::with_capacity(num_layers);
628
629    for layer_idx in 0..num_layers {
630        let samples = cache.layer(layer_idx);
631        if samples.is_empty() {
632            matrices.push(L::zeros_mat(0, 0));
633            continue;
634        }
635        let n_neurons = L::vec_len(&samples[0]);
636        let mut mat = L::zeros_mat(batch_size, n_neurons);
637        for (r, sample) in samples.iter().enumerate() {
638            for c in 0..n_neurons {
639                L::mat_set(&mut mat, r, c, L::vec_get(sample, c));
640            }
641        }
642        matrices.push(mat);
643    }
644
645    matrices
646}
647
648#[cfg(test)]
649mod tests {
650    use super::*;
651    use crate::activation::Activation;
652    use crate::layer::LayerDef;
653    use crate::pc_actor::SelectionMode;
654
655    fn default_config() -> PcActorCriticConfig {
656        PcActorCriticConfig {
657            actor: PcActorConfig {
658                input_size: 9,
659                hidden_layers: vec![LayerDef {
660                    size: 18,
661                    activation: Activation::Tanh,
662                }],
663                output_size: 9,
664                output_activation: Activation::Tanh,
665                alpha: 0.1,
666                tol: 0.01,
667                min_steps: 1,
668                max_steps: 20,
669                lr_weights: 0.01,
670                synchronous: true,
671                temperature: 1.0,
672                local_lambda: 1.0,
673                residual: false,
674                rezero_init: 0.001,
675            },
676            critic: MlpCriticConfig {
677                input_size: 27,
678                hidden_layers: vec![LayerDef {
679                    size: 36,
680                    activation: Activation::Tanh,
681                }],
682                output_activation: Activation::Linear,
683                lr: 0.005,
684            },
685            gamma: 0.95,
686            surprise_low: 0.02,
687            surprise_high: 0.15,
688            adaptive_surprise: false,
689            surprise_buffer_size: 100,
690            entropy_coeff: 0.01,
691        }
692    }
693
694    fn make_agent() -> PcActorCritic {
695        let agent: PcActorCritic = PcActorCritic::new(default_config(), 42).unwrap();
696        agent
697    }
698
699    fn make_trajectory(agent: &mut PcActorCritic) -> Vec<TrajectoryStep> {
700        let input = vec![1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0, 0.0, 0.5];
701        let valid = vec![2, 7];
702        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
703        vec![TrajectoryStep {
704            input,
705            latent_concat: infer.latent_concat,
706            y_conv: infer.y_conv,
707            hidden_states: infer.hidden_states,
708            prediction_errors: infer.prediction_errors,
709            tanh_components: infer.tanh_components,
710            action,
711            valid_actions: valid,
712            reward: 1.0,
713            surprise_score: infer.surprise_score,
714            steps_used: infer.steps_used,
715        }]
716    }
717
718    // ── learn tests ───────────────────────────────────────────────
719
720    #[test]
721    fn test_learn_empty_returns_zero_without_modifying_weights() {
722        let mut agent: PcActorCritic = make_agent();
723        let w_before = agent.actor.layers[0].weights.data.clone();
724        let cw_before = agent.critic.layers[0].weights.data.clone();
725        let loss = agent.learn(&[]);
726        assert_eq!(loss, 0.0);
727        assert_eq!(agent.actor.layers[0].weights.data, w_before);
728        assert_eq!(agent.critic.layers[0].weights.data, cw_before);
729    }
730
731    #[test]
732    fn test_learn_updates_actor_weights() {
733        let mut agent: PcActorCritic = make_agent();
734        let trajectory = make_trajectory(&mut agent);
735        let w_before = agent.actor.layers[0].weights.data.clone();
736        let _ = agent.learn(&trajectory);
737        assert_ne!(agent.actor.layers[0].weights.data, w_before);
738    }
739
740    #[test]
741    fn test_learn_updates_critic_weights() {
742        let mut agent: PcActorCritic = make_agent();
743        let trajectory = make_trajectory(&mut agent);
744        let w_before = agent.critic.layers[0].weights.data.clone();
745        let _ = agent.learn(&trajectory);
746        assert_ne!(agent.critic.layers[0].weights.data, w_before);
747    }
748
749    #[test]
750    fn test_learn_returns_finite_nonneg_loss() {
751        let mut agent: PcActorCritic = make_agent();
752        let trajectory = make_trajectory(&mut agent);
753        let loss = agent.learn(&trajectory);
754        assert!(loss.is_finite(), "Loss {loss} is not finite");
755        assert!(loss >= 0.0, "Loss {loss} is negative");
756    }
757
758    #[test]
759    fn test_learn_single_step_trajectory() {
760        let mut agent: PcActorCritic = make_agent();
761        let input = vec![0.5; 9];
762        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
763        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
764        let trajectory = vec![TrajectoryStep {
765            input,
766            latent_concat: infer.latent_concat,
767            y_conv: infer.y_conv,
768            hidden_states: infer.hidden_states,
769            prediction_errors: infer.prediction_errors,
770            tanh_components: infer.tanh_components,
771            action,
772            valid_actions: valid,
773            reward: -1.0,
774            surprise_score: infer.surprise_score,
775            steps_used: infer.steps_used,
776        }];
777        let loss = agent.learn(&trajectory);
778        assert!(loss.is_finite());
779    }
780
781    #[test]
782    fn test_learn_multi_step_uses_stored_hidden_states() {
783        // Build a 3-step trajectory to exercise multi-step learning
784        let mut agent: PcActorCritic = make_agent();
785        let inputs = [
786            vec![1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0, 0.0, 0.5],
787            vec![0.5, 0.5, -1.0, 0.0, 1.0, -0.5, 0.0, -1.0, 0.5],
788            vec![-1.0, 0.0, 1.0, -0.5, 0.5, 0.0, 1.0, -1.0, -0.5],
789        ];
790        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
791
792        let mut trajectory = Vec::new();
793        for (i, inp) in inputs.iter().enumerate() {
794            let (action, infer) = agent.act(inp, &valid, SelectionMode::Training);
795            trajectory.push(TrajectoryStep {
796                input: inp.clone(),
797                latent_concat: infer.latent_concat,
798                y_conv: infer.y_conv,
799                hidden_states: infer.hidden_states,
800                prediction_errors: infer.prediction_errors,
801                tanh_components: infer.tanh_components,
802                action,
803                valid_actions: valid.clone(),
804                reward: if i == 2 { 1.0 } else { 0.0 },
805                surprise_score: infer.surprise_score,
806                steps_used: infer.steps_used,
807            });
808        }
809
810        let loss = agent.learn(&trajectory);
811        assert!(
812            loss.is_finite(),
813            "Multi-step learn should produce finite loss"
814        );
815        assert!(loss >= 0.0);
816    }
817
818    // ── learn_continuous tests ────────────────────────────────────
819
820    #[test]
821    fn test_learn_continuous_nonterminal_uses_next_value() {
822        let mut agent: PcActorCritic = make_agent();
823        let input = vec![0.5; 9];
824        let next_input = vec![-0.5; 9];
825        let valid = vec![0, 1, 2];
826        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
827        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
828
829        // Non-terminal: should incorporate next value
830        let loss = agent.learn_continuous(
831            &input,
832            &infer,
833            action,
834            &valid,
835            0.5,
836            &next_input,
837            &next_infer,
838            false,
839        );
840        assert!(loss.is_finite());
841    }
842
843    #[test]
844    fn test_learn_continuous_terminal_uses_reward_only() {
845        let mut agent: PcActorCritic = make_agent();
846        let input = vec![0.5; 9];
847        let next_input = vec![0.0; 9];
848        let valid = vec![0, 1, 2];
849        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
850        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
851
852        // Terminal: target = reward only (no gamma * V(s'))
853        let loss = agent.learn_continuous(
854            &input,
855            &infer,
856            action,
857            &valid,
858            1.0,
859            &next_input,
860            &next_infer,
861            true,
862        );
863        assert!(loss.is_finite());
864    }
865
866    #[test]
867    fn test_learn_continuous_terminal_and_nonterminal_produce_different_updates() {
868        // Create two identical agents
869        let config = default_config();
870        let mut agent_term: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
871        let mut agent_nonterm: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
872
873        let input = vec![0.5; 9];
874        let next_input = vec![-0.5; 9];
875        let valid = vec![0, 1, 2];
876
877        // Use identical actions and inferences
878        let (action, infer) = agent_term.act(&input, &valid, SelectionMode::Training);
879        let (_, next_infer) = agent_term.act(&next_input, &valid, SelectionMode::Training);
880
881        // Clone infer for the non-terminal agent (same starting point)
882        let (action2, infer2) = agent_nonterm.act(&input, &valid, SelectionMode::Training);
883        let (_, next_infer2) = agent_nonterm.act(&next_input, &valid, SelectionMode::Training);
884
885        // Terminal update
886        let loss_term = agent_term.learn_continuous(
887            &input,
888            &infer,
889            action,
890            &valid,
891            1.0,
892            &next_input,
893            &next_infer,
894            true,
895        );
896
897        // Non-terminal update with same reward
898        let loss_nonterm = agent_nonterm.learn_continuous(
899            &input,
900            &infer2,
901            action2,
902            &valid,
903            1.0,
904            &next_input,
905            &next_infer2,
906            false,
907        );
908
909        // The losses should differ because terminal uses target=reward
910        // while non-terminal uses target=reward+gamma*V(s')
911        assert!(
912            (loss_term - loss_nonterm).abs() > 1e-15,
913            "Terminal and non-terminal should produce different losses: {loss_term} vs {loss_nonterm}"
914        );
915    }
916
917    #[test]
918    fn test_learn_continuous_updates_actor() {
919        let mut agent: PcActorCritic = make_agent();
920        let input = vec![0.5; 9];
921        let next_input = vec![-0.5; 9];
922        let valid = vec![0, 1, 2];
923        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
924        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
925        let w_before = agent.actor.layers[0].weights.data.clone();
926        let _ = agent.learn_continuous(
927            &input,
928            &infer,
929            action,
930            &valid,
931            1.0,
932            &next_input,
933            &next_infer,
934            false,
935        );
936        assert_ne!(agent.actor.layers[0].weights.data, w_before);
937    }
938
939    // ── surprise_scale tests ─────────────────────────────────────
940
941    #[test]
942    fn test_surprise_scale_below_low() {
943        let agent: PcActorCritic = make_agent();
944        let scale = agent.surprise_scale(0.01); // below low=0.02
945        assert!((scale - 0.1).abs() < 1e-12, "Expected 0.1, got {scale}");
946    }
947
948    #[test]
949    fn test_surprise_scale_above_high() {
950        let agent: PcActorCritic = make_agent();
951        let scale = agent.surprise_scale(0.20); // above high=0.15
952        assert!((scale - 2.0).abs() < 1e-12, "Expected 2.0, got {scale}");
953    }
954
955    #[test]
956    fn test_surprise_scale_midpoint_in_range() {
957        let agent: PcActorCritic = make_agent();
958        let midpoint = (0.02 + 0.15) / 2.0;
959        let scale = agent.surprise_scale(midpoint);
960        assert!(
961            scale > 0.1 && scale < 2.0,
962            "Midpoint scale {scale} out of range"
963        );
964    }
965
966    #[test]
967    fn test_surprise_scale_monotone_increasing() {
968        let agent: PcActorCritic = make_agent();
969        let s1 = agent.surprise_scale(0.01);
970        let s2 = agent.surprise_scale(0.05);
971        let s3 = agent.surprise_scale(0.10);
972        let s4 = agent.surprise_scale(0.20);
973        assert!(s1 <= s2, "s1={s1} > s2={s2}");
974        assert!(s2 <= s3, "s2={s2} > s3={s3}");
975        assert!(s3 <= s4, "s3={s3} > s4={s4}");
976    }
977
978    #[test]
979    fn test_adaptive_surprise_recalibrates_thresholds_after_many_episodes() {
980        let mut config = default_config();
981        config.adaptive_surprise = true;
982        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
983
984        // Fill buffer with varied surprise scores to get nonzero std
985        for i in 0..15 {
986            agent.push_surprise(0.1 + 0.02 * i as f64);
987        }
988
989        // mean ≈ 0.24, std ≈ 0.089
990        // adaptive low = max(0, mean - 0.5*std) ≈ 0.196
991        // adaptive high = mean + 1.5*std ≈ 0.373
992        // These differ from the static defaults (0.02, 0.15)
993
994        // Something well below adaptive low should get 0.1
995        let scale_low = agent.surprise_scale(0.0);
996        assert!(
997            (scale_low - 0.1).abs() < 1e-12,
998            "Expected 0.1 below adaptive low: got {scale_low}"
999        );
1000
1001        // Something well above adaptive high should get 2.0
1002        let scale_high = agent.surprise_scale(1.0);
1003        assert!(
1004            (scale_high - 2.0).abs() < 1e-12,
1005            "Expected 2.0 above adaptive high: got {scale_high}"
1006        );
1007
1008        // Something at the mean should be between 0.1 and 2.0
1009        let scale_mid = agent.surprise_scale(0.24);
1010        assert!(
1011            scale_mid > 0.1 && scale_mid < 2.0,
1012            "Expected interpolated value at mean, got {scale_mid}"
1013        );
1014    }
1015
1016    #[test]
1017    fn test_entropy_regularization_prevents_policy_collapse() {
1018        // With entropy regularization, repeated learning on same trajectory
1019        // should keep the policy from collapsing to a single action
1020        let mut config = default_config();
1021        config.entropy_coeff = 0.1; // Strong entropy
1022        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
1023
1024        let input = vec![0.5; 9];
1025        let valid: Vec<usize> = (0..9).collect();
1026
1027        // Train many times on same trajectory
1028        for _ in 0..20 {
1029            let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
1030            let trajectory = vec![TrajectoryStep {
1031                input: input.clone(),
1032                latent_concat: infer.latent_concat,
1033                y_conv: infer.y_conv,
1034                hidden_states: infer.hidden_states,
1035                prediction_errors: infer.prediction_errors,
1036                tanh_components: infer.tanh_components,
1037                action,
1038                valid_actions: valid.clone(),
1039                reward: 1.0,
1040                surprise_score: infer.surprise_score,
1041                steps_used: infer.steps_used,
1042            }];
1043            let _ = agent.learn(&trajectory);
1044        }
1045
1046        // Check that policy is not collapsed (multiple actions selected over 50 trials)
1047        let mut seen = std::collections::HashSet::new();
1048        for _ in 0..50 {
1049            let (action, _) = agent.act(&input, &valid, SelectionMode::Training);
1050            seen.insert(action);
1051        }
1052        assert!(
1053            seen.len() > 1,
1054            "Entropy regularization should prevent collapse to single action, but only saw {:?}",
1055            seen
1056        );
1057    }
1058
1059    // ── act tests ─────────────────────────────────────────────────
1060
1061    #[test]
1062    fn test_act_returns_valid_action() {
1063        let mut agent: PcActorCritic = make_agent();
1064        let input = vec![0.5; 9];
1065        let valid = vec![1, 3, 5, 7];
1066        for _ in 0..20 {
1067            let (action, _) = agent.act(&input, &valid, SelectionMode::Training);
1068            assert!(valid.contains(&action), "Action {action} not in valid set");
1069        }
1070    }
1071
1072    #[test]
1073    #[should_panic]
1074    fn test_act_empty_valid_panics() {
1075        let mut agent: PcActorCritic = make_agent();
1076        let input = vec![0.5; 9];
1077        let _ = agent.act(&input, &[], SelectionMode::Training);
1078    }
1079
1080    // ── learning diagnostic test ──────────────────────────────
1081
1082    #[test]
1083    fn test_learn_improves_policy_for_rewarded_action() {
1084        // Linear output so logits are unbounded
1085        let config = PcActorCriticConfig {
1086            actor: PcActorConfig {
1087                input_size: 9,
1088                hidden_layers: vec![LayerDef {
1089                    size: 18,
1090                    activation: Activation::Tanh,
1091                }],
1092                output_size: 9,
1093                output_activation: Activation::Linear,
1094                alpha: 0.1,
1095                tol: 0.01,
1096                min_steps: 1,
1097                max_steps: 5,
1098                lr_weights: 0.01,
1099                synchronous: true,
1100                temperature: 1.0,
1101                local_lambda: 1.0,
1102                residual: false,
1103                rezero_init: 0.001,
1104            },
1105            critic: MlpCriticConfig {
1106                input_size: 27,
1107                hidden_layers: vec![LayerDef {
1108                    size: 36,
1109                    activation: Activation::Tanh,
1110                }],
1111                output_activation: Activation::Linear,
1112                lr: 0.005,
1113            },
1114            gamma: 0.99,
1115            surprise_low: 0.02,
1116            surprise_high: 0.15,
1117            adaptive_surprise: false,
1118            surprise_buffer_size: 100,
1119            entropy_coeff: 0.0, // no entropy to isolate gradient effect
1120        };
1121        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
1122
1123        let input = vec![0.0; 9];
1124        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
1125        let target_action = 4; // center
1126
1127        // Repeatedly reward action 4
1128        for _ in 0..200 {
1129            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1130            let trajectory = vec![TrajectoryStep {
1131                input: input.clone(),
1132                latent_concat: infer.latent_concat,
1133                y_conv: infer.y_conv,
1134                hidden_states: infer.hidden_states,
1135                prediction_errors: infer.prediction_errors,
1136                tanh_components: infer.tanh_components,
1137                action: target_action,
1138                valid_actions: valid.clone(),
1139                reward: 1.0,
1140                surprise_score: infer.surprise_score,
1141                steps_used: infer.steps_used,
1142            }];
1143            agent.learn(&trajectory);
1144        }
1145
1146        // After 200 episodes always rewarding action 4, it should be the
1147        // preferred action in Play mode (deterministic argmax)
1148        let (action, infer) = agent.act(&input, &valid, SelectionMode::Play);
1149
1150        // Check that action 4's logit is the highest
1151        let logit_4 = infer.y_conv[4];
1152        let max_other = valid
1153            .iter()
1154            .filter(|&&a| a != 4)
1155            .map(|&a| infer.y_conv[a])
1156            .fold(f64::NEG_INFINITY, f64::max);
1157
1158        eprintln!(
1159            "DIAGNOSTIC: action={action}, logit[4]={logit_4:.4}, max_other={max_other:.4}, \
1160             y_conv={:?}",
1161            infer
1162                .y_conv
1163                .iter()
1164                .map(|v| format!("{v:.3}"))
1165                .collect::<Vec<_>>()
1166        );
1167
1168        assert_eq!(
1169            action, target_action,
1170            "After 200 episodes rewarding action 4, agent should prefer it. Got action {action}"
1171        );
1172    }
1173
1174    // ── config validation tests ────────────────────────────────
1175
1176    #[test]
1177    fn test_new_returns_error_zero_temperature() {
1178        let mut config = default_config();
1179        config.actor.temperature = 0.0;
1180        let err = PcActorCritic::new(config, 42)
1181            .map(|_: PcActorCritic| ())
1182            .unwrap_err();
1183        assert!(format!("{err}").contains("temperature"));
1184    }
1185
1186    #[test]
1187    fn test_new_returns_error_zero_input_size() {
1188        let mut config = default_config();
1189        config.actor.input_size = 0;
1190        config.critic.input_size = 0;
1191        assert!(PcActorCritic::new(config, 42)
1192            .map(|_: PcActorCritic| ())
1193            .is_err());
1194    }
1195
1196    #[test]
1197    fn test_new_returns_error_zero_output_size() {
1198        let mut config = default_config();
1199        config.actor.output_size = 0;
1200        assert!(PcActorCritic::new(config, 42)
1201            .map(|_: PcActorCritic| ())
1202            .is_err());
1203    }
1204
1205    #[test]
1206    fn test_new_returns_error_negative_gamma() {
1207        let mut config = default_config();
1208        config.gamma = -0.1;
1209        let err = PcActorCritic::new(config, 42)
1210            .map(|_: PcActorCritic| ())
1211            .unwrap_err();
1212        assert!(format!("{err}").contains("gamma"));
1213    }
1214
1215    // ── Phase 4 Cycle 4.1: ActivationCache construction and recording ──
1216
1217    #[test]
1218    fn test_activation_cache_new_creates_empty() {
1219        let cache: ActivationCache = ActivationCache::new(3);
1220        assert_eq!(cache.batch_size(), 0);
1221    }
1222
1223    #[test]
1224    fn test_activation_cache_record_increments_batch_size() {
1225        let mut agent: PcActorCritic = make_agent();
1226        let input = vec![0.5; 9];
1227        let valid = vec![0, 1, 2];
1228        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1229
1230        let num_hidden = infer.hidden_states.len();
1231        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1232        cache.record(&infer.hidden_states);
1233        assert_eq!(cache.batch_size(), 1);
1234    }
1235
1236    #[test]
1237    fn test_activation_cache_record_multiple() {
1238        let mut agent: PcActorCritic = make_agent();
1239        let valid = vec![0, 1, 2];
1240        let init_input = vec![0.5; 9];
1241        let num_hidden = {
1242            let (_, infer) = agent.act(&init_input, &valid, SelectionMode::Training);
1243            infer.hidden_states.len()
1244        };
1245
1246        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1247        for i in 0..5 {
1248            let input = vec![i as f64 * 0.1; 9];
1249            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1250            cache.record(&infer.hidden_states);
1251        }
1252        assert_eq!(cache.batch_size(), 5);
1253    }
1254
1255    #[test]
1256    fn test_activation_cache_recorded_values_match_hidden_states() {
1257        let mut agent: PcActorCritic = make_agent();
1258        let input = vec![0.5; 9];
1259        let valid = vec![0, 1, 2];
1260        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1261
1262        let num_hidden = infer.hidden_states.len();
1263        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1264        cache.record(&infer.hidden_states);
1265
1266        // Verify recorded activations match
1267        for (layer_idx, expected) in infer.hidden_states.iter().enumerate() {
1268            let layer_data = cache.layer(layer_idx);
1269            assert_eq!(layer_data.len(), 1);
1270            assert_eq!(layer_data[0], *expected);
1271        }
1272    }
1273
1274    // ── Phase 4 Cycle 4.2: ActivationCache layer access ────────────
1275
1276    #[test]
1277    fn test_activation_cache_layer_count() {
1278        let mut agent: PcActorCritic = make_agent();
1279        let input = vec![0.5; 9];
1280        let valid = vec![0, 1, 2];
1281        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1282
1283        let num_hidden = infer.hidden_states.len();
1284        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1285        cache.record(&infer.hidden_states);
1286
1287        assert_eq!(cache.num_layers(), num_hidden);
1288    }
1289
1290    #[test]
1291    fn test_activation_cache_layer_sample_count() {
1292        let mut agent: PcActorCritic = make_agent();
1293        let valid = vec![0, 1, 2];
1294        let init_input = vec![0.5; 9];
1295        let num_hidden = {
1296            let (_, infer) = agent.act(&init_input, &valid, SelectionMode::Training);
1297            infer.hidden_states.len()
1298        };
1299
1300        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1301        for i in 0..10 {
1302            let input = vec![i as f64 * 0.1; 9];
1303            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1304            cache.record(&infer.hidden_states);
1305        }
1306
1307        for layer_idx in 0..num_hidden {
1308            assert_eq!(
1309                cache.layer(layer_idx).len(),
1310                10,
1311                "Layer {layer_idx} should have 10 samples"
1312            );
1313        }
1314    }
1315
1316    // ── Phase 7 Cycle 7.1: PcActorCritic::crossover ────────────
1317
1318    fn build_caches_for_agent(
1319        agent: &mut PcActorCritic,
1320        batch_size: usize,
1321    ) -> (ActivationCache, ActivationCache) {
1322        let num_actor_hidden = agent.config.actor.hidden_layers.len();
1323        let num_critic_hidden = agent.config.critic.hidden_layers.len();
1324        let mut actor_cache: ActivationCache = ActivationCache::new(num_actor_hidden);
1325        let mut critic_cache: ActivationCache = ActivationCache::new(num_critic_hidden);
1326        let valid: Vec<usize> = (0..agent.config.actor.output_size).collect();
1327        for i in 0..batch_size {
1328            let input: Vec<f64> = (0..agent.config.actor.input_size)
1329                .map(|j| ((i * 9 + j) as f64 * 0.1).sin())
1330                .collect();
1331            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1332            actor_cache.record(&infer.hidden_states);
1333            let mut critic_input = input;
1334            critic_input.extend_from_slice(&infer.latent_concat);
1335            let (_value, critic_hidden) = agent.critic.forward_with_hidden(&critic_input);
1336            critic_cache.record(&critic_hidden);
1337        }
1338        (actor_cache, critic_cache)
1339    }
1340
1341    #[test]
1342    fn test_agent_crossover_produces_valid_agent() {
1343        let config = default_config();
1344        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1345        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1346
1347        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1348        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1349
1350        let child: PcActorCritic = PcActorCritic::crossover(
1351            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1352        )
1353        .unwrap();
1354
1355        assert_eq!(
1356            child.config.actor.hidden_layers.len(),
1357            agent_a.config.actor.hidden_layers.len()
1358        );
1359    }
1360
1361    #[test]
1362    fn test_agent_crossover_actor_weights_differ() {
1363        let config = default_config();
1364        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1365        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1366
1367        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1368        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1369
1370        let child: PcActorCritic = PcActorCritic::crossover(
1371            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1372        )
1373        .unwrap();
1374
1375        assert_ne!(
1376            child.actor.layers[0].weights.data,
1377            agent_a.actor.layers[0].weights.data
1378        );
1379        assert_ne!(
1380            child.actor.layers[0].weights.data,
1381            agent_b.actor.layers[0].weights.data
1382        );
1383    }
1384
1385    #[test]
1386    fn test_agent_crossover_critic_weights_differ() {
1387        let config = default_config();
1388        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1389        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1390
1391        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1392        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1393
1394        let child: PcActorCritic = PcActorCritic::crossover(
1395            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1396        )
1397        .unwrap();
1398
1399        assert_ne!(
1400            child.critic.layers[0].weights.data,
1401            agent_a.critic.layers[0].weights.data
1402        );
1403        assert_ne!(
1404            child.critic.layers[0].weights.data,
1405            agent_b.critic.layers[0].weights.data
1406        );
1407    }
1408
1409    // ── Phase 7 Cycle 7.2: Integration — full GA workflow ───────
1410
1411    #[test]
1412    fn test_agent_crossover_child_can_infer() {
1413        let config = default_config();
1414        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1415        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1416
1417        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1418        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1419
1420        let mut child: PcActorCritic = PcActorCritic::crossover(
1421            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1422        )
1423        .unwrap();
1424
1425        let input = vec![0.5; 9];
1426        let valid = vec![0, 1, 2, 3, 4];
1427        let (action, _) = child.act(&input, &valid, SelectionMode::Training);
1428        assert!(valid.contains(&action), "Action {action} not in valid set");
1429    }
1430
1431    #[test]
1432    fn test_agent_crossover_child_can_learn() {
1433        let config = default_config();
1434        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1435        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1436
1437        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1438        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1439
1440        let mut child: PcActorCritic = PcActorCritic::crossover(
1441            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1442        )
1443        .unwrap();
1444
1445        let trajectory = make_trajectory(&mut child);
1446        let loss = child.learn(&trajectory);
1447        assert!(loss.is_finite(), "Child learn loss not finite: {loss}");
1448    }
1449
1450    #[test]
1451    fn test_agent_crossover_mismatched_batch_size_error() {
1452        let config = default_config();
1453        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1454        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1455
1456        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1457        let (ac_b, _cc_b) = build_caches_for_agent(&mut agent_b, 30); // different batch
1458        let (_, cc_b_match) = build_caches_for_agent(&mut agent_b, 50);
1459
1460        // Actor batch mismatch
1461        let result = PcActorCritic::crossover(
1462            &agent_a,
1463            &agent_b,
1464            &ac_a,
1465            &ac_b,
1466            &cc_a,
1467            &cc_b_match,
1468            0.5,
1469            config,
1470            99,
1471        );
1472        assert!(result.is_err(), "Mismatched actor batch sizes should error");
1473    }
1474
1475    // ── Fix #2: Separate critic caches in crossover ────────────
1476
1477    #[test]
1478    fn test_agent_crossover_with_separate_critic_caches() {
1479        let config = default_config();
1480        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1481        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1482
1483        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1484        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1485
1486        let child: PcActorCritic = PcActorCritic::crossover(
1487            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1488        )
1489        .unwrap();
1490
1491        assert_eq!(child.critic.layers.len(), agent_a.critic.layers.len());
1492    }
1493
1494    #[test]
1495    fn test_agent_crossover_critic_uses_own_caches() {
1496        let config = default_config();
1497        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1498        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1499
1500        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1501        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1502
1503        let child: PcActorCritic = PcActorCritic::crossover(
1504            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1505        )
1506        .unwrap();
1507
1508        assert_ne!(
1509            child.critic.layers[0].weights.data,
1510            agent_a.critic.layers[0].weights.data
1511        );
1512        assert_ne!(
1513            child.critic.layers[0].weights.data,
1514            agent_b.critic.layers[0].weights.data
1515        );
1516    }
1517
1518    #[test]
1519    fn test_agent_crossover_mismatched_critic_batch_error() {
1520        let config = default_config();
1521        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1522        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1523
1524        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1525        let (ac_b, _) = build_caches_for_agent(&mut agent_b, 50);
1526        // Build critic cache with different batch size
1527        let (_, cc_b_small) = build_caches_for_agent(&mut agent_b, 30);
1528
1529        let result = PcActorCritic::crossover(
1530            &agent_a,
1531            &agent_b,
1532            &ac_a,
1533            &ac_b,
1534            &cc_a,
1535            &cc_b_small,
1536            0.5,
1537            config,
1538            99,
1539        );
1540        assert!(
1541            result.is_err(),
1542            "Mismatched critic batch sizes should error"
1543        );
1544    }
1545
1546    // ── Phase 7 Cycle 7.3: lib.rs re-exports ────────────────────
1547
1548    #[test]
1549    fn test_activation_cache_accessible_from_crate() {
1550        // Verify ActivationCache is accessible via pc_actor_critic module
1551        let _cache: crate::pc_actor_critic::ActivationCache = ActivationCache::new(1);
1552    }
1553
1554    #[test]
1555    fn test_cca_neuron_alignment_accessible_from_crate() {
1556        // Verify cca_neuron_alignment is accessible via matrix module
1557        use crate::linalg::cpu::CpuLinAlg;
1558        use crate::linalg::LinAlg;
1559        let mat = CpuLinAlg::zeros_mat(10, 3);
1560        let _perm = crate::matrix::cca_neuron_alignment::<CpuLinAlg>(&mat, &mat).unwrap();
1561    }
1562}