Skip to main content

pc_rl_core/
pc_actor_critic.rs

1// Author: Julian Bolivar
2// Version: 1.0.0
3// Date: 2026-03-25
4
5//! Integrated PC Actor-Critic agent.
6//!
7//! Combines [`PcActor`] for action selection via predictive coding inference
8//! with [`MlpCritic`] for value estimation. Supports REINFORCE episodic
9//! learning, TD(0) continuous learning, surprise-based scheduling, and
10//! entropy regularization.
11//!
12//! Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
13
14use std::collections::VecDeque;
15
16use rand::rngs::StdRng;
17use serde::{Deserialize, Serialize};
18
19use crate::error::PcError;
20use crate::linalg::cpu::CpuLinAlg;
21use crate::linalg::LinAlg;
22use crate::mlp_critic::{MlpCritic, MlpCriticConfig};
23use crate::pc_actor::{InferResult, PcActor, PcActorConfig, SelectionMode};
24
/// Configuration for the integrated PC Actor-Critic agent.
///
/// Bundles the actor and critic sub-configurations with the
/// reinforcement-learning hyperparameters that [`PcActorCritic`] reads while
/// learning (discounting, surprise-based update scaling, entropy bonus).
///
/// # Examples
///
/// ```
/// use pc_rl_core::activation::Activation;
/// use pc_rl_core::layer::LayerDef;
/// use pc_rl_core::mlp_critic::MlpCriticConfig;
/// use pc_rl_core::pc_actor::PcActorConfig;
/// use pc_rl_core::pc_actor_critic::PcActorCriticConfig;
///
/// let config = PcActorCriticConfig {
///     actor: PcActorConfig {
///         input_size: 9,
///         hidden_layers: vec![LayerDef { size: 18, activation: Activation::Tanh }],
///         output_size: 9,
///         output_activation: Activation::Tanh,
///         alpha: 0.1, tol: 0.01, min_steps: 1, max_steps: 20,
///         lr_weights: 0.01, synchronous: true, temperature: 1.0,
///         local_lambda: 1.0,
///         residual: false,
///         rezero_init: 0.001,
///     },
///     critic: MlpCriticConfig {
///         input_size: 27,
///         hidden_layers: vec![LayerDef { size: 36, activation: Activation::Tanh }],
///         output_activation: Activation::Linear,
///         lr: 0.005,
///     },
///     gamma: 0.95,
///     surprise_low: 0.02,
///     surprise_high: 0.15,
///     adaptive_surprise: false,
///     entropy_coeff: 0.01,
/// };
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PcActorCriticConfig {
    /// Actor (PC network) configuration.
    pub actor: PcActorConfig,
    /// Critic (MLP value function) configuration.
    ///
    /// The agent feeds the critic the state vector followed by the actor's
    /// concatenated latent activations, so the critic's input size has to
    /// account for both parts.
    pub critic: MlpCriticConfig,
    /// Discount factor for computing returns.
    ///
    /// `PcActorCritic::new` rejects values outside `[0.0, 1.0]`.
    pub gamma: f64,
    /// Surprise threshold at or below which the actor update is scaled down
    /// to its minimum factor (0.1).
    pub surprise_low: f64,
    /// Surprise threshold at or above which the actor update is scaled up
    /// to its maximum factor (2.0).
    pub surprise_high: f64,
    /// Whether to adaptively recalibrate surprise thresholds.
    ///
    /// When enabled, surprise scores seen during learning are buffered and,
    /// once at least 10 are available, thresholds derived from their mean and
    /// standard deviation replace `surprise_low` / `surprise_high`.
    pub adaptive_surprise: bool,
    /// Entropy regularization coefficient applied to the policy-gradient
    /// signal for every valid action.
    pub entropy_coeff: f64,
}
78
/// A single step in a trajectory collected during an episode.
///
/// Stores everything `PcActorCritic::learn` needs to replay the step without
/// re-running inference: the state, the inference outputs captured when the
/// action was chosen, the action itself and the reward that followed.
///
/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
#[derive(Debug, Clone)]
pub struct TrajectoryStep<L: LinAlg = CpuLinAlg> {
    /// Board state input vector.
    pub input: L::Vector,
    /// Concatenated hidden layer activations from inference.
    ///
    /// Appended to `input` to form the critic's input during learning.
    pub latent_concat: L::Vector,
    /// Converged output logits from inference.
    pub y_conv: L::Vector,
    /// Per-layer hidden state activations from inference (for backprop).
    pub hidden_states: Vec<L::Vector>,
    /// Per-layer prediction errors from the PC inference loop.
    pub prediction_errors: Vec<L::Vector>,
    /// Per-layer tanh components for residual layers (for correct backward pass).
    pub tanh_components: Vec<Option<L::Vector>>,
    /// Action taken at this step.
    ///
    /// Used as an index into the policy vector during learning.
    pub action: usize,
    /// Valid actions at this step (needed for masked softmax).
    pub valid_actions: Vec<usize>,
    /// Reward received after taking the action.
    pub reward: f64,
    /// Prediction error from inference.
    ///
    /// Drives the surprise-based scale applied to the actor update.
    pub surprise_score: f64,
    /// Number of inference steps used.
    pub steps_used: usize,
}
107
108/// Cache for hidden layer activations captured during inference.
109///
110/// Used by CCA neuron alignment during crossover to compare functional
111/// neuron responses between parent networks. Activations are recorded
112/// during normal fitness evaluation at zero additional compute cost.
113///
114/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
115///
116/// # Examples
117///
118/// ```
119/// use pc_rl_core::pc_actor_critic::ActivationCache;
120/// use pc_rl_core::linalg::cpu::CpuLinAlg;
121///
122/// let cache: ActivationCache<CpuLinAlg> = ActivationCache::new(2);
123/// assert_eq!(cache.batch_size(), 0);
124/// assert_eq!(cache.num_layers(), 2);
125/// ```
126#[derive(Debug, Clone)]
127pub struct ActivationCache<L: LinAlg = CpuLinAlg> {
128    /// activations[layer_idx][batch_sample_idx] = neuron activations.
129    layers: Vec<Vec<L::Vector>>,
130}
131
132impl<L: LinAlg> ActivationCache<L> {
133    /// Creates a new empty cache with the given number of hidden layers.
134    ///
135    /// # Arguments
136    ///
137    /// * `num_layers` - Number of hidden layers to track.
138    pub fn new(num_layers: usize) -> Self {
139        Self {
140            layers: (0..num_layers).map(|_| Vec::new()).collect(),
141        }
142    }
143
144    /// Returns the number of samples recorded so far.
145    pub fn batch_size(&self) -> usize {
146        self.layers.first().map_or(0, |l| l.len())
147    }
148
149    /// Returns the number of hidden layers in the cache.
150    pub fn num_layers(&self) -> usize {
151        self.layers.len()
152    }
153
154    /// Records hidden layer activations from a single inference.
155    ///
156    /// # Arguments
157    ///
158    /// * `hidden_states` - Per-layer activations from `InferResult::hidden_states`.
159    pub fn record(&mut self, hidden_states: &[L::Vector]) {
160        for (layer, state) in self.layers.iter_mut().zip(hidden_states.iter()) {
161            layer.push(state.clone());
162        }
163    }
164
165    /// Returns the recorded activations for a given layer.
166    ///
167    /// # Arguments
168    ///
169    /// * `layer_idx` - Index of the hidden layer.
170    pub fn layer(&self, layer_idx: usize) -> &[L::Vector] {
171        &self.layers[layer_idx]
172    }
173}
174
/// Integrated PC Actor-Critic agent.
///
/// Combines a predictive coding actor with an MLP critic for
/// reinforcement learning with surprise-based scheduling.
///
/// The agent owns its random number generator, so action sampling is
/// reproducible for a given seed and call sequence.
///
/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
#[derive(Debug)]
pub struct PcActorCritic<L: LinAlg = CpuLinAlg> {
    /// The PC actor network.
    pub(crate) actor: PcActor<L>,
    /// The MLP critic (value function).
    pub(crate) critic: MlpCritic<L>,
    /// Agent configuration.
    ///
    /// Public and therefore mutable after construction; the range check on
    /// `gamma` is only performed by `PcActorCritic::new`.
    pub config: PcActorCriticConfig,
    /// Random number generator for action selection.
    rng: StdRng,
    /// Circular buffer of recent surprise scores for adaptive thresholds.
    ///
    /// Holds at most the 100 most recent scores and is only filled while
    /// `config.adaptive_surprise` is enabled.
    surprise_buffer: VecDeque<f64>,
}
194
195impl<L: LinAlg> PcActorCritic<L> {
196    /// Creates a new PC Actor-Critic agent.
197    ///
198    /// # Arguments
199    ///
200    /// * `config` - Agent configuration with actor, critic, and learning parameters.
201    /// * `seed` - Random seed for reproducibility.
202    /// # Errors
203    ///
204    /// Returns `PcError::ConfigValidation` if gamma is out of `[0.0, 1.0]`,
205    /// or if actor/critic config is invalid.
206    pub fn new(config: PcActorCriticConfig, seed: u64) -> Result<Self, PcError> {
207        if !(0.0..=1.0).contains(&config.gamma) {
208            return Err(PcError::ConfigValidation(format!(
209                "gamma must be in [0.0, 1.0], got {}",
210                config.gamma
211            )));
212        }
213
214        use rand::SeedableRng;
215        let mut rng = StdRng::seed_from_u64(seed);
216        let actor = PcActor::<L>::new(config.actor.clone(), &mut rng)?;
217        let critic = MlpCritic::<L>::new(config.critic.clone(), &mut rng)?;
218        Ok(Self {
219            actor,
220            critic,
221            config,
222            rng,
223            surprise_buffer: VecDeque::new(),
224        })
225    }
226
227    /// Creates a child agent by crossing over two parent agents using CCA neuron alignment.
228    ///
229    /// Delegates to `PcActor::crossover` and `MlpCritic::crossover`, converting
230    /// activation caches to the matrix format expected by CCA alignment.
231    ///
232    /// # Arguments
233    ///
234    /// * `parent_a` - First parent agent (reference, typically higher fitness).
235    /// * `parent_b` - Second parent agent.
236    /// * `cache_a` - Activation cache for parent A on the reference batch.
237    /// * `cache_b` - Activation cache for parent B on the reference batch.
238    /// * `alpha` - Blending weight: 1.0 = all A, 0.0 = all B.
239    /// * `child_config` - Configuration for the child agent.
240    /// * `seed` - Random seed for the child's RNG.
241    ///
242    /// # Errors
243    ///
244    /// Returns `PcError::DimensionMismatch` if activation caches have different
245    /// batch sizes. Returns `PcError::ConfigValidation` if child config is invalid.
246    #[allow(clippy::too_many_arguments)]
247    pub fn crossover(
248        parent_a: &PcActorCritic<L>,
249        parent_b: &PcActorCritic<L>,
250        actor_cache_a: &ActivationCache<L>,
251        actor_cache_b: &ActivationCache<L>,
252        critic_cache_a: &ActivationCache<L>,
253        critic_cache_b: &ActivationCache<L>,
254        alpha: f64,
255        child_config: PcActorCriticConfig,
256        seed: u64,
257    ) -> Result<Self, PcError> {
258        // Validate actor batch sizes match
259        if actor_cache_a.batch_size() != actor_cache_b.batch_size() {
260            return Err(PcError::DimensionMismatch {
261                expected: actor_cache_a.batch_size(),
262                got: actor_cache_b.batch_size(),
263                context: "actor activation cache batch sizes must match for crossover",
264            });
265        }
266        // Validate critic batch sizes match
267        if critic_cache_a.batch_size() != critic_cache_b.batch_size() {
268            return Err(PcError::DimensionMismatch {
269                expected: critic_cache_a.batch_size(),
270                got: critic_cache_b.batch_size(),
271                context: "critic activation cache batch sizes must match for crossover",
272            });
273        }
274
275        // Convert caches to matrices [batch × neurons] for CCA
276        let actor_cache_mats_a = cache_to_matrices::<L>(actor_cache_a);
277        let actor_cache_mats_b = cache_to_matrices::<L>(actor_cache_b);
278        let critic_cache_mats_a = cache_to_matrices::<L>(critic_cache_a);
279        let critic_cache_mats_b = cache_to_matrices::<L>(critic_cache_b);
280
281        use rand::SeedableRng;
282        let mut rng = StdRng::seed_from_u64(seed);
283
284        // Crossover actor with actor-specific caches
285        let actor = PcActor::<L>::crossover(
286            &parent_a.actor,
287            &parent_b.actor,
288            &actor_cache_mats_a,
289            &actor_cache_mats_b,
290            alpha,
291            child_config.actor.clone(),
292            &mut rng,
293        )?;
294
295        // Crossover critic with critic-specific caches
296        let critic = MlpCritic::<L>::crossover(
297            &parent_a.critic,
298            &parent_b.critic,
299            &critic_cache_mats_a,
300            &critic_cache_mats_b,
301            alpha,
302            child_config.critic.clone(),
303            &mut rng,
304        )?;
305
306        Ok(Self {
307            actor,
308            critic,
309            config: child_config,
310            rng,
311            surprise_buffer: VecDeque::new(),
312        })
313    }
314
315    /// Reconstructs an agent from pre-built components (used by serializer).
316    ///
317    /// # Arguments
318    ///
319    /// * `config` - Agent configuration.
320    /// * `actor` - Pre-built PC actor with loaded weights.
321    /// * `critic` - Pre-built MLP critic with loaded weights.
322    /// * `rng` - Random number generator.
323    pub fn from_parts(
324        config: PcActorCriticConfig,
325        actor: PcActor<L>,
326        critic: MlpCritic<L>,
327        rng: StdRng,
328    ) -> Self {
329        Self {
330            actor,
331            critic,
332            config,
333            rng,
334            surprise_buffer: VecDeque::new(),
335        }
336    }
337
    /// Runs PC inference without selecting an action or modifying RNG state.
    ///
    /// This is a direct delegation to the actor's `infer`. It takes `&self`,
    /// so neither the agent's RNG nor its surprise buffer is touched.
    ///
    /// Use this when you only need the inference result (e.g., for TD(0)
    /// next-state evaluation) without side effects.
    ///
    /// # Arguments
    ///
    /// * `input` - Board state vector.
    ///
    /// # Returns
    ///
    /// The [`InferResult`] produced by the actor for `input`.
    ///
    /// # Panics
    ///
    /// Panics if `input.len() != config.actor.input_size`. The check lives in
    /// `PcActor::infer`, which is not part of this file.
    pub fn infer(&self, input: &[f64]) -> InferResult<L> {
        self.actor.infer(input)
    }
353
354    /// Selects an action given the current state.
355    ///
356    /// Runs PC inference on the input, then selects an action using the
357    /// converged logits and the specified selection mode.
358    ///
359    /// # Arguments
360    ///
361    /// * `input` - Board state vector.
362    /// * `valid_actions` - Indices of legal actions.
363    /// * `mode` - Training (stochastic) or Play (deterministic).
364    ///
365    /// # Panics
366    ///
367    /// Panics if `valid_actions` is empty.
368    pub fn act(
369        &mut self,
370        input: &[f64],
371        valid_actions: &[usize],
372        mode: SelectionMode,
373    ) -> (usize, InferResult<L>) {
374        let infer_result = self.actor.infer(input);
375        let action =
376            self.actor
377                .select_action(&infer_result.y_conv, valid_actions, mode, &mut self.rng);
378        (action, infer_result)
379    }
380
381    /// Learns from a complete episode trajectory using REINFORCE with baseline.
382    ///
383    /// Empty trajectory returns 0.0 without modifying weights. Otherwise computes
384    /// discounted returns, advantages, and updates both actor and critic.
385    ///
386    /// # Arguments
387    ///
388    /// * `trajectory` - Sequence of steps from an episode.
389    ///
390    /// # Returns
391    ///
392    /// Average critic loss over the trajectory.
393    pub fn learn(&mut self, trajectory: &[TrajectoryStep<L>]) -> f64 {
394        if trajectory.is_empty() {
395            return 0.0;
396        }
397
398        let n = trajectory.len();
399
400        // Compute discounted returns backward
401        let mut returns = vec![0.0; n];
402        returns[n - 1] = trajectory[n - 1].reward;
403        for t in (0..n - 1).rev() {
404            returns[t] = trajectory[t].reward + self.config.gamma * returns[t + 1];
405        }
406
407        let mut total_loss = 0.0;
408
409        for (t, step) in trajectory.iter().enumerate() {
410            // Build critic input: concat(input, latent_concat)
411            let input_vec = L::vec_to_vec(&step.input);
412            let latent_vec = L::vec_to_vec(&step.latent_concat);
413            let mut critic_input = input_vec.clone();
414            critic_input.extend_from_slice(&latent_vec);
415
416            // V(s)
417            let value = self.critic.forward(&critic_input);
418            let advantage = returns[t] - value;
419
420            // Update critic toward discounted return
421            let loss = self.critic.update(&critic_input, returns[t]);
422            total_loss += loss;
423
424            // Policy gradient
425            let y_conv_vec = L::vec_to_vec(&step.y_conv);
426            let scaled: Vec<f64> = y_conv_vec
427                .iter()
428                .map(|&v| v / self.actor.config.temperature)
429                .collect();
430            let scaled_l = L::vec_from_slice(&scaled);
431            let pi_l = L::softmax_masked(&scaled_l, &step.valid_actions);
432            let pi = L::vec_to_vec(&pi_l);
433
434            let mut delta = vec![0.0; pi.len()];
435            for &i in &step.valid_actions {
436                delta[i] = pi[i];
437            }
438            delta[step.action] -= 1.0;
439
440            // Scale by advantage
441            for &i in &step.valid_actions {
442                delta[i] *= advantage;
443            }
444
445            // Entropy regularization
446            for &i in &step.valid_actions {
447                let log_pi = (pi[i].max(1e-10)).ln();
448                delta[i] -= self.config.entropy_coeff * (log_pi + 1.0);
449            }
450
451            // Compute surprise scale and update actor using stored hidden_states
452            let s_scale = self.surprise_scale(step.surprise_score);
453
454            let stored_infer = InferResult {
455                y_conv: step.y_conv.clone(),
456                latent_concat: step.latent_concat.clone(),
457                hidden_states: step.hidden_states.clone(),
458                prediction_errors: step.prediction_errors.clone(),
459                surprise_score: step.surprise_score,
460                steps_used: step.steps_used,
461                converged: false,
462                tanh_components: step.tanh_components.clone(),
463            };
464            self.actor
465                .update_weights(&delta, &stored_infer, &input_vec, s_scale);
466
467            // Push surprise to adaptive buffer
468            if self.config.adaptive_surprise {
469                self.push_surprise(step.surprise_score);
470            }
471        }
472
473        total_loss / n as f64
474    }
475
476    /// Single-step TD(0) continuous learning.
477    ///
478    /// # Arguments
479    ///
480    /// * `input` - Current state.
481    /// * `infer` - Inference result from `act` at current state.
482    /// * `action` - Action taken.
483    /// * `valid_actions` - Valid actions at current state.
484    /// * `reward` - Reward received.
485    /// * `next_input` - Next state.
486    /// * `next_infer` - Inference result from `act` at next state.
487    /// * `terminal` - Whether the episode ended.
488    ///
489    /// # Returns
490    ///
491    /// Critic loss for this step.
492    #[allow(clippy::too_many_arguments)]
493    pub fn learn_continuous(
494        &mut self,
495        input: &[f64],
496        infer: &InferResult<L>,
497        action: usize,
498        valid_actions: &[usize],
499        reward: f64,
500        next_input: &[f64],
501        next_infer: &InferResult<L>,
502        terminal: bool,
503    ) -> f64 {
504        // Build critic inputs
505        let latent_vec = L::vec_to_vec(&infer.latent_concat);
506        let mut critic_input = input.to_vec();
507        critic_input.extend_from_slice(&latent_vec);
508
509        let next_latent_vec = L::vec_to_vec(&next_infer.latent_concat);
510        let mut next_critic_input = next_input.to_vec();
511        next_critic_input.extend_from_slice(&next_latent_vec);
512
513        let v_s = self.critic.forward(&critic_input);
514        let v_next = if terminal {
515            0.0
516        } else {
517            self.critic.forward(&next_critic_input)
518        };
519
520        let target = reward
521            + if terminal {
522                0.0
523            } else {
524                self.config.gamma * v_next
525            };
526        let td_error = target - v_s;
527
528        // Update critic
529        let loss = self.critic.update(&critic_input, target);
530
531        // Policy gradient (same formula as learn, but scaled by td_error)
532        let y_conv_vec = L::vec_to_vec(&infer.y_conv);
533        let scaled: Vec<f64> = y_conv_vec
534            .iter()
535            .map(|&v| v / self.actor.config.temperature)
536            .collect();
537        let scaled_l = L::vec_from_slice(&scaled);
538        let pi_l = L::softmax_masked(&scaled_l, valid_actions);
539        let pi = L::vec_to_vec(&pi_l);
540
541        let mut delta = vec![0.0; pi.len()];
542        for &i in valid_actions {
543            delta[i] = pi[i];
544        }
545        delta[action] -= 1.0;
546
547        for &i in valid_actions {
548            delta[i] *= td_error;
549        }
550
551        // Entropy regularization
552        for &i in valid_actions {
553            let log_pi = (pi[i].max(1e-10)).ln();
554            delta[i] -= self.config.entropy_coeff * (log_pi + 1.0);
555        }
556
557        let s_scale = self.surprise_scale(infer.surprise_score);
558        self.actor.update_weights(&delta, infer, input, s_scale);
559
560        if self.config.adaptive_surprise {
561            self.push_surprise(infer.surprise_score);
562        }
563
564        loss
565    }
566
567    /// Computes the learning rate scale factor based on surprise score.
568    ///
569    /// - surprise <= low → 0.1
570    /// - surprise >= high → 2.0
571    /// - Between → linear interpolation from 0.1 to 2.0
572    ///
573    /// If adaptive surprise is enabled and the buffer has >= 10 entries,
574    /// thresholds are dynamically recomputed from the buffer statistics.
575    pub fn surprise_scale(&self, surprise: f64) -> f64 {
576        let (low, high) = if self.config.adaptive_surprise && self.surprise_buffer.len() >= 10 {
577            let mean = self.surprise_buffer.iter().sum::<f64>() / self.surprise_buffer.len() as f64;
578            let variance = self
579                .surprise_buffer
580                .iter()
581                .map(|&s| (s - mean) * (s - mean))
582                .sum::<f64>()
583                / self.surprise_buffer.len() as f64;
584            let std = variance.sqrt();
585            let lo = (mean - 0.5 * std).max(0.0);
586            let hi = mean + 1.5 * std;
587            (lo, hi)
588        } else {
589            (self.config.surprise_low, self.config.surprise_high)
590        };
591
592        if surprise <= low {
593            0.1
594        } else if surprise >= high {
595            2.0
596        } else {
597            // Linear interpolation
598            let t = (surprise - low) / (high - low);
599            0.1 + t * (2.0 - 0.1)
600        }
601    }
602
603    /// Pushes a surprise score into the adaptive buffer (circular, max 100).
604    fn push_surprise(&mut self, surprise: f64) {
605        if self.surprise_buffer.len() >= 100 {
606            self.surprise_buffer.pop_front();
607        }
608        self.surprise_buffer.push_back(surprise);
609    }
610}
611
612/// Converts an `ActivationCache` into a vector of matrices `[batch × neurons]`,
613/// one per hidden layer, suitable for CCA alignment.
614fn cache_to_matrices<L: LinAlg>(cache: &ActivationCache<L>) -> Vec<L::Matrix> {
615    let num_layers = cache.num_layers();
616    let batch_size = cache.batch_size();
617    let mut matrices = Vec::with_capacity(num_layers);
618
619    for layer_idx in 0..num_layers {
620        let samples = cache.layer(layer_idx);
621        if samples.is_empty() {
622            matrices.push(L::zeros_mat(0, 0));
623            continue;
624        }
625        let n_neurons = L::vec_len(&samples[0]);
626        let mut mat = L::zeros_mat(batch_size, n_neurons);
627        for (r, sample) in samples.iter().enumerate() {
628            for c in 0..n_neurons {
629                L::mat_set(&mut mat, r, c, L::vec_get(sample, c));
630            }
631        }
632        matrices.push(mat);
633    }
634
635    matrices
636}
637
638#[cfg(test)]
639mod tests {
640    use super::*;
641    use crate::activation::Activation;
642    use crate::layer::LayerDef;
643    use crate::pc_actor::SelectionMode;
644
645    fn default_config() -> PcActorCriticConfig {
646        PcActorCriticConfig {
647            actor: PcActorConfig {
648                input_size: 9,
649                hidden_layers: vec![LayerDef {
650                    size: 18,
651                    activation: Activation::Tanh,
652                }],
653                output_size: 9,
654                output_activation: Activation::Tanh,
655                alpha: 0.1,
656                tol: 0.01,
657                min_steps: 1,
658                max_steps: 20,
659                lr_weights: 0.01,
660                synchronous: true,
661                temperature: 1.0,
662                local_lambda: 1.0,
663                residual: false,
664                rezero_init: 0.001,
665            },
666            critic: MlpCriticConfig {
667                input_size: 27,
668                hidden_layers: vec![LayerDef {
669                    size: 36,
670                    activation: Activation::Tanh,
671                }],
672                output_activation: Activation::Linear,
673                lr: 0.005,
674            },
675            gamma: 0.95,
676            surprise_low: 0.02,
677            surprise_high: 0.15,
678            adaptive_surprise: false,
679            entropy_coeff: 0.01,
680        }
681    }
682
683    fn make_agent() -> PcActorCritic {
684        let agent: PcActorCritic = PcActorCritic::new(default_config(), 42).unwrap();
685        agent
686    }
687
688    fn make_trajectory(agent: &mut PcActorCritic) -> Vec<TrajectoryStep> {
689        let input = vec![1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0, 0.0, 0.5];
690        let valid = vec![2, 7];
691        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
692        vec![TrajectoryStep {
693            input,
694            latent_concat: infer.latent_concat,
695            y_conv: infer.y_conv,
696            hidden_states: infer.hidden_states,
697            prediction_errors: infer.prediction_errors,
698            tanh_components: infer.tanh_components,
699            action,
700            valid_actions: valid,
701            reward: 1.0,
702            surprise_score: infer.surprise_score,
703            steps_used: infer.steps_used,
704        }]
705    }
706
707    // ── learn tests ───────────────────────────────────────────────
708
709    #[test]
710    fn test_learn_empty_returns_zero_without_modifying_weights() {
711        let mut agent: PcActorCritic = make_agent();
712        let w_before = agent.actor.layers[0].weights.data.clone();
713        let cw_before = agent.critic.layers[0].weights.data.clone();
714        let loss = agent.learn(&[]);
715        assert_eq!(loss, 0.0);
716        assert_eq!(agent.actor.layers[0].weights.data, w_before);
717        assert_eq!(agent.critic.layers[0].weights.data, cw_before);
718    }
719
720    #[test]
721    fn test_learn_updates_actor_weights() {
722        let mut agent: PcActorCritic = make_agent();
723        let trajectory = make_trajectory(&mut agent);
724        let w_before = agent.actor.layers[0].weights.data.clone();
725        let _ = agent.learn(&trajectory);
726        assert_ne!(agent.actor.layers[0].weights.data, w_before);
727    }
728
729    #[test]
730    fn test_learn_updates_critic_weights() {
731        let mut agent: PcActorCritic = make_agent();
732        let trajectory = make_trajectory(&mut agent);
733        let w_before = agent.critic.layers[0].weights.data.clone();
734        let _ = agent.learn(&trajectory);
735        assert_ne!(agent.critic.layers[0].weights.data, w_before);
736    }
737
738    #[test]
739    fn test_learn_returns_finite_nonneg_loss() {
740        let mut agent: PcActorCritic = make_agent();
741        let trajectory = make_trajectory(&mut agent);
742        let loss = agent.learn(&trajectory);
743        assert!(loss.is_finite(), "Loss {loss} is not finite");
744        assert!(loss >= 0.0, "Loss {loss} is negative");
745    }
746
747    #[test]
748    fn test_learn_single_step_trajectory() {
749        let mut agent: PcActorCritic = make_agent();
750        let input = vec![0.5; 9];
751        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
752        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
753        let trajectory = vec![TrajectoryStep {
754            input,
755            latent_concat: infer.latent_concat,
756            y_conv: infer.y_conv,
757            hidden_states: infer.hidden_states,
758            prediction_errors: infer.prediction_errors,
759            tanh_components: infer.tanh_components,
760            action,
761            valid_actions: valid,
762            reward: -1.0,
763            surprise_score: infer.surprise_score,
764            steps_used: infer.steps_used,
765        }];
766        let loss = agent.learn(&trajectory);
767        assert!(loss.is_finite());
768    }
769
770    #[test]
771    fn test_learn_multi_step_uses_stored_hidden_states() {
772        // Build a 3-step trajectory to exercise multi-step learning
773        let mut agent: PcActorCritic = make_agent();
774        let inputs = [
775            vec![1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0, 0.0, 0.5],
776            vec![0.5, 0.5, -1.0, 0.0, 1.0, -0.5, 0.0, -1.0, 0.5],
777            vec![-1.0, 0.0, 1.0, -0.5, 0.5, 0.0, 1.0, -1.0, -0.5],
778        ];
779        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
780
781        let mut trajectory = Vec::new();
782        for (i, inp) in inputs.iter().enumerate() {
783            let (action, infer) = agent.act(inp, &valid, SelectionMode::Training);
784            trajectory.push(TrajectoryStep {
785                input: inp.clone(),
786                latent_concat: infer.latent_concat,
787                y_conv: infer.y_conv,
788                hidden_states: infer.hidden_states,
789                prediction_errors: infer.prediction_errors,
790                tanh_components: infer.tanh_components,
791                action,
792                valid_actions: valid.clone(),
793                reward: if i == 2 { 1.0 } else { 0.0 },
794                surprise_score: infer.surprise_score,
795                steps_used: infer.steps_used,
796            });
797        }
798
799        let loss = agent.learn(&trajectory);
800        assert!(
801            loss.is_finite(),
802            "Multi-step learn should produce finite loss"
803        );
804        assert!(loss >= 0.0);
805    }
806
807    // ── learn_continuous tests ────────────────────────────────────
808
809    #[test]
810    fn test_learn_continuous_nonterminal_uses_next_value() {
811        let mut agent: PcActorCritic = make_agent();
812        let input = vec![0.5; 9];
813        let next_input = vec![-0.5; 9];
814        let valid = vec![0, 1, 2];
815        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
816        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
817
818        // Non-terminal: should incorporate next value
819        let loss = agent.learn_continuous(
820            &input,
821            &infer,
822            action,
823            &valid,
824            0.5,
825            &next_input,
826            &next_infer,
827            false,
828        );
829        assert!(loss.is_finite());
830    }
831
832    #[test]
833    fn test_learn_continuous_terminal_uses_reward_only() {
834        let mut agent: PcActorCritic = make_agent();
835        let input = vec![0.5; 9];
836        let next_input = vec![0.0; 9];
837        let valid = vec![0, 1, 2];
838        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
839        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
840
841        // Terminal: target = reward only (no gamma * V(s'))
842        let loss = agent.learn_continuous(
843            &input,
844            &infer,
845            action,
846            &valid,
847            1.0,
848            &next_input,
849            &next_infer,
850            true,
851        );
852        assert!(loss.is_finite());
853    }
854
855    #[test]
856    fn test_learn_continuous_terminal_and_nonterminal_produce_different_updates() {
857        // Create two identical agents
858        let config = default_config();
859        let mut agent_term: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
860        let mut agent_nonterm: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
861
862        let input = vec![0.5; 9];
863        let next_input = vec![-0.5; 9];
864        let valid = vec![0, 1, 2];
865
866        // Use identical actions and inferences
867        let (action, infer) = agent_term.act(&input, &valid, SelectionMode::Training);
868        let (_, next_infer) = agent_term.act(&next_input, &valid, SelectionMode::Training);
869
870        // Clone infer for the non-terminal agent (same starting point)
871        let (action2, infer2) = agent_nonterm.act(&input, &valid, SelectionMode::Training);
872        let (_, next_infer2) = agent_nonterm.act(&next_input, &valid, SelectionMode::Training);
873
874        // Terminal update
875        let loss_term = agent_term.learn_continuous(
876            &input,
877            &infer,
878            action,
879            &valid,
880            1.0,
881            &next_input,
882            &next_infer,
883            true,
884        );
885
886        // Non-terminal update with same reward
887        let loss_nonterm = agent_nonterm.learn_continuous(
888            &input,
889            &infer2,
890            action2,
891            &valid,
892            1.0,
893            &next_input,
894            &next_infer2,
895            false,
896        );
897
898        // The losses should differ because terminal uses target=reward
899        // while non-terminal uses target=reward+gamma*V(s')
900        assert!(
901            (loss_term - loss_nonterm).abs() > 1e-15,
902            "Terminal and non-terminal should produce different losses: {loss_term} vs {loss_nonterm}"
903        );
904    }
905
906    #[test]
907    fn test_learn_continuous_updates_actor() {
908        let mut agent: PcActorCritic = make_agent();
909        let input = vec![0.5; 9];
910        let next_input = vec![-0.5; 9];
911        let valid = vec![0, 1, 2];
912        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
913        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
914        let w_before = agent.actor.layers[0].weights.data.clone();
915        let _ = agent.learn_continuous(
916            &input,
917            &infer,
918            action,
919            &valid,
920            1.0,
921            &next_input,
922            &next_infer,
923            false,
924        );
925        assert_ne!(agent.actor.layers[0].weights.data, w_before);
926    }
927
928    // ── surprise_scale tests ─────────────────────────────────────
929
930    #[test]
931    fn test_surprise_scale_below_low() {
932        let agent: PcActorCritic = make_agent();
933        let scale = agent.surprise_scale(0.01); // below low=0.02
934        assert!((scale - 0.1).abs() < 1e-12, "Expected 0.1, got {scale}");
935    }
936
937    #[test]
938    fn test_surprise_scale_above_high() {
939        let agent: PcActorCritic = make_agent();
940        let scale = agent.surprise_scale(0.20); // above high=0.15
941        assert!((scale - 2.0).abs() < 1e-12, "Expected 2.0, got {scale}");
942    }
943
944    #[test]
945    fn test_surprise_scale_midpoint_in_range() {
946        let agent: PcActorCritic = make_agent();
947        let midpoint = (0.02 + 0.15) / 2.0;
948        let scale = agent.surprise_scale(midpoint);
949        assert!(
950            scale > 0.1 && scale < 2.0,
951            "Midpoint scale {scale} out of range"
952        );
953    }
954
955    #[test]
956    fn test_surprise_scale_monotone_increasing() {
957        let agent: PcActorCritic = make_agent();
958        let s1 = agent.surprise_scale(0.01);
959        let s2 = agent.surprise_scale(0.05);
960        let s3 = agent.surprise_scale(0.10);
961        let s4 = agent.surprise_scale(0.20);
962        assert!(s1 <= s2, "s1={s1} > s2={s2}");
963        assert!(s2 <= s3, "s2={s2} > s3={s3}");
964        assert!(s3 <= s4, "s3={s3} > s4={s4}");
965    }
966
967    #[test]
968    fn test_adaptive_surprise_recalibrates_thresholds_after_many_episodes() {
969        let mut config = default_config();
970        config.adaptive_surprise = true;
971        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
972
973        // Fill buffer with varied surprise scores to get nonzero std
974        for i in 0..15 {
975            agent.push_surprise(0.1 + 0.02 * i as f64);
976        }
977
978        // mean ≈ 0.24, std ≈ 0.089
979        // adaptive low = max(0, mean - 0.5*std) ≈ 0.196
980        // adaptive high = mean + 1.5*std ≈ 0.373
981        // These differ from the static defaults (0.02, 0.15)
982
983        // Something well below adaptive low should get 0.1
984        let scale_low = agent.surprise_scale(0.0);
985        assert!(
986            (scale_low - 0.1).abs() < 1e-12,
987            "Expected 0.1 below adaptive low: got {scale_low}"
988        );
989
990        // Something well above adaptive high should get 2.0
991        let scale_high = agent.surprise_scale(1.0);
992        assert!(
993            (scale_high - 2.0).abs() < 1e-12,
994            "Expected 2.0 above adaptive high: got {scale_high}"
995        );
996
997        // Something at the mean should be between 0.1 and 2.0
998        let scale_mid = agent.surprise_scale(0.24);
999        assert!(
1000            scale_mid > 0.1 && scale_mid < 2.0,
1001            "Expected interpolated value at mean, got {scale_mid}"
1002        );
1003    }
1004
1005    #[test]
1006    fn test_entropy_regularization_prevents_policy_collapse() {
1007        // With entropy regularization, repeated learning on same trajectory
1008        // should keep the policy from collapsing to a single action
1009        let mut config = default_config();
1010        config.entropy_coeff = 0.1; // Strong entropy
1011        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
1012
1013        let input = vec![0.5; 9];
1014        let valid: Vec<usize> = (0..9).collect();
1015
1016        // Train many times on same trajectory
1017        for _ in 0..20 {
1018            let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
1019            let trajectory = vec![TrajectoryStep {
1020                input: input.clone(),
1021                latent_concat: infer.latent_concat,
1022                y_conv: infer.y_conv,
1023                hidden_states: infer.hidden_states,
1024                prediction_errors: infer.prediction_errors,
1025                tanh_components: infer.tanh_components,
1026                action,
1027                valid_actions: valid.clone(),
1028                reward: 1.0,
1029                surprise_score: infer.surprise_score,
1030                steps_used: infer.steps_used,
1031            }];
1032            let _ = agent.learn(&trajectory);
1033        }
1034
1035        // Check that policy is not collapsed (multiple actions selected over 50 trials)
1036        let mut seen = std::collections::HashSet::new();
1037        for _ in 0..50 {
1038            let (action, _) = agent.act(&input, &valid, SelectionMode::Training);
1039            seen.insert(action);
1040        }
1041        assert!(
1042            seen.len() > 1,
1043            "Entropy regularization should prevent collapse to single action, but only saw {:?}",
1044            seen
1045        );
1046    }
1047
1048    // ── act tests ─────────────────────────────────────────────────
1049
1050    #[test]
1051    fn test_act_returns_valid_action() {
1052        let mut agent: PcActorCritic = make_agent();
1053        let input = vec![0.5; 9];
1054        let valid = vec![1, 3, 5, 7];
1055        for _ in 0..20 {
1056            let (action, _) = agent.act(&input, &valid, SelectionMode::Training);
1057            assert!(valid.contains(&action), "Action {action} not in valid set");
1058        }
1059    }
1060
1061    #[test]
1062    #[should_panic]
1063    fn test_act_empty_valid_panics() {
1064        let mut agent: PcActorCritic = make_agent();
1065        let input = vec![0.5; 9];
1066        let _ = agent.act(&input, &[], SelectionMode::Training);
1067    }
1068
1069    // ── learning diagnostic test ──────────────────────────────
1070
1071    #[test]
1072    fn test_learn_improves_policy_for_rewarded_action() {
1073        // Linear output so logits are unbounded
1074        let config = PcActorCriticConfig {
1075            actor: PcActorConfig {
1076                input_size: 9,
1077                hidden_layers: vec![LayerDef {
1078                    size: 18,
1079                    activation: Activation::Tanh,
1080                }],
1081                output_size: 9,
1082                output_activation: Activation::Linear,
1083                alpha: 0.1,
1084                tol: 0.01,
1085                min_steps: 1,
1086                max_steps: 5,
1087                lr_weights: 0.01,
1088                synchronous: true,
1089                temperature: 1.0,
1090                local_lambda: 1.0,
1091                residual: false,
1092                rezero_init: 0.001,
1093            },
1094            critic: MlpCriticConfig {
1095                input_size: 27,
1096                hidden_layers: vec![LayerDef {
1097                    size: 36,
1098                    activation: Activation::Tanh,
1099                }],
1100                output_activation: Activation::Linear,
1101                lr: 0.005,
1102            },
1103            gamma: 0.99,
1104            surprise_low: 0.02,
1105            surprise_high: 0.15,
1106            adaptive_surprise: false,
1107            entropy_coeff: 0.0, // no entropy to isolate gradient effect
1108        };
1109        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
1110
1111        let input = vec![0.0; 9];
1112        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
1113        let target_action = 4; // center
1114
1115        // Repeatedly reward action 4
1116        for _ in 0..200 {
1117            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1118            let trajectory = vec![TrajectoryStep {
1119                input: input.clone(),
1120                latent_concat: infer.latent_concat,
1121                y_conv: infer.y_conv,
1122                hidden_states: infer.hidden_states,
1123                prediction_errors: infer.prediction_errors,
1124                tanh_components: infer.tanh_components,
1125                action: target_action,
1126                valid_actions: valid.clone(),
1127                reward: 1.0,
1128                surprise_score: infer.surprise_score,
1129                steps_used: infer.steps_used,
1130            }];
1131            agent.learn(&trajectory);
1132        }
1133
1134        // After 200 episodes always rewarding action 4, it should be the
1135        // preferred action in Play mode (deterministic argmax)
1136        let (action, infer) = agent.act(&input, &valid, SelectionMode::Play);
1137
1138        // Check that action 4's logit is the highest
1139        let logit_4 = infer.y_conv[4];
1140        let max_other = valid
1141            .iter()
1142            .filter(|&&a| a != 4)
1143            .map(|&a| infer.y_conv[a])
1144            .fold(f64::NEG_INFINITY, f64::max);
1145
1146        eprintln!(
1147            "DIAGNOSTIC: action={action}, logit[4]={logit_4:.4}, max_other={max_other:.4}, \
1148             y_conv={:?}",
1149            infer
1150                .y_conv
1151                .iter()
1152                .map(|v| format!("{v:.3}"))
1153                .collect::<Vec<_>>()
1154        );
1155
1156        assert_eq!(
1157            action, target_action,
1158            "After 200 episodes rewarding action 4, agent should prefer it. Got action {action}"
1159        );
1160    }
1161
1162    // ── config validation tests ────────────────────────────────
1163
1164    #[test]
1165    fn test_new_returns_error_zero_temperature() {
1166        let mut config = default_config();
1167        config.actor.temperature = 0.0;
1168        let err = PcActorCritic::new(config, 42)
1169            .map(|_: PcActorCritic| ())
1170            .unwrap_err();
1171        assert!(format!("{err}").contains("temperature"));
1172    }
1173
1174    #[test]
1175    fn test_new_returns_error_zero_input_size() {
1176        let mut config = default_config();
1177        config.actor.input_size = 0;
1178        config.critic.input_size = 0;
1179        assert!(PcActorCritic::new(config, 42)
1180            .map(|_: PcActorCritic| ())
1181            .is_err());
1182    }
1183
1184    #[test]
1185    fn test_new_returns_error_zero_output_size() {
1186        let mut config = default_config();
1187        config.actor.output_size = 0;
1188        assert!(PcActorCritic::new(config, 42)
1189            .map(|_: PcActorCritic| ())
1190            .is_err());
1191    }
1192
1193    #[test]
1194    fn test_new_returns_error_negative_gamma() {
1195        let mut config = default_config();
1196        config.gamma = -0.1;
1197        let err = PcActorCritic::new(config, 42)
1198            .map(|_: PcActorCritic| ())
1199            .unwrap_err();
1200        assert!(format!("{err}").contains("gamma"));
1201    }
1202
1203    // ── Phase 4 Cycle 4.1: ActivationCache construction and recording ──
1204
1205    #[test]
1206    fn test_activation_cache_new_creates_empty() {
1207        let cache: ActivationCache = ActivationCache::new(3);
1208        assert_eq!(cache.batch_size(), 0);
1209    }
1210
1211    #[test]
1212    fn test_activation_cache_record_increments_batch_size() {
1213        let mut agent: PcActorCritic = make_agent();
1214        let input = vec![0.5; 9];
1215        let valid = vec![0, 1, 2];
1216        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1217
1218        let num_hidden = infer.hidden_states.len();
1219        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1220        cache.record(&infer.hidden_states);
1221        assert_eq!(cache.batch_size(), 1);
1222    }
1223
1224    #[test]
1225    fn test_activation_cache_record_multiple() {
1226        let mut agent: PcActorCritic = make_agent();
1227        let valid = vec![0, 1, 2];
1228        let init_input = vec![0.5; 9];
1229        let num_hidden = {
1230            let (_, infer) = agent.act(&init_input, &valid, SelectionMode::Training);
1231            infer.hidden_states.len()
1232        };
1233
1234        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1235        for i in 0..5 {
1236            let input = vec![i as f64 * 0.1; 9];
1237            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1238            cache.record(&infer.hidden_states);
1239        }
1240        assert_eq!(cache.batch_size(), 5);
1241    }
1242
1243    #[test]
1244    fn test_activation_cache_recorded_values_match_hidden_states() {
1245        let mut agent: PcActorCritic = make_agent();
1246        let input = vec![0.5; 9];
1247        let valid = vec![0, 1, 2];
1248        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1249
1250        let num_hidden = infer.hidden_states.len();
1251        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1252        cache.record(&infer.hidden_states);
1253
1254        // Verify recorded activations match
1255        for (layer_idx, expected) in infer.hidden_states.iter().enumerate() {
1256            let layer_data = cache.layer(layer_idx);
1257            assert_eq!(layer_data.len(), 1);
1258            assert_eq!(layer_data[0], *expected);
1259        }
1260    }
1261
1262    // ── Phase 4 Cycle 4.2: ActivationCache layer access ────────────
1263
1264    #[test]
1265    fn test_activation_cache_layer_count() {
1266        let mut agent: PcActorCritic = make_agent();
1267        let input = vec![0.5; 9];
1268        let valid = vec![0, 1, 2];
1269        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1270
1271        let num_hidden = infer.hidden_states.len();
1272        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1273        cache.record(&infer.hidden_states);
1274
1275        assert_eq!(cache.num_layers(), num_hidden);
1276    }
1277
1278    #[test]
1279    fn test_activation_cache_layer_sample_count() {
1280        let mut agent: PcActorCritic = make_agent();
1281        let valid = vec![0, 1, 2];
1282        let init_input = vec![0.5; 9];
1283        let num_hidden = {
1284            let (_, infer) = agent.act(&init_input, &valid, SelectionMode::Training);
1285            infer.hidden_states.len()
1286        };
1287
1288        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1289        for i in 0..10 {
1290            let input = vec![i as f64 * 0.1; 9];
1291            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1292            cache.record(&infer.hidden_states);
1293        }
1294
1295        for layer_idx in 0..num_hidden {
1296            assert_eq!(
1297                cache.layer(layer_idx).len(),
1298                10,
1299                "Layer {layer_idx} should have 10 samples"
1300            );
1301        }
1302    }
1303
1304    // ── Phase 7 Cycle 7.1: PcActorCritic::crossover ────────────
1305
1306    fn build_caches_for_agent(
1307        agent: &mut PcActorCritic,
1308        batch_size: usize,
1309    ) -> (ActivationCache, ActivationCache) {
1310        let num_actor_hidden = agent.config.actor.hidden_layers.len();
1311        let num_critic_hidden = agent.config.critic.hidden_layers.len();
1312        let mut actor_cache: ActivationCache = ActivationCache::new(num_actor_hidden);
1313        let mut critic_cache: ActivationCache = ActivationCache::new(num_critic_hidden);
1314        let valid: Vec<usize> = (0..agent.config.actor.output_size).collect();
1315        for i in 0..batch_size {
1316            let input: Vec<f64> = (0..agent.config.actor.input_size)
1317                .map(|j| ((i * 9 + j) as f64 * 0.1).sin())
1318                .collect();
1319            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1320            actor_cache.record(&infer.hidden_states);
1321            let mut critic_input = input;
1322            critic_input.extend_from_slice(&infer.latent_concat);
1323            let (_value, critic_hidden) = agent.critic.forward_with_hidden(&critic_input);
1324            critic_cache.record(&critic_hidden);
1325        }
1326        (actor_cache, critic_cache)
1327    }
1328
1329    #[test]
1330    fn test_agent_crossover_produces_valid_agent() {
1331        let config = default_config();
1332        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1333        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1334
1335        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1336        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1337
1338        let child: PcActorCritic = PcActorCritic::crossover(
1339            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1340        )
1341        .unwrap();
1342
1343        assert_eq!(
1344            child.config.actor.hidden_layers.len(),
1345            agent_a.config.actor.hidden_layers.len()
1346        );
1347    }
1348
1349    #[test]
1350    fn test_agent_crossover_actor_weights_differ() {
1351        let config = default_config();
1352        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1353        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1354
1355        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1356        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1357
1358        let child: PcActorCritic = PcActorCritic::crossover(
1359            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1360        )
1361        .unwrap();
1362
1363        assert_ne!(
1364            child.actor.layers[0].weights.data,
1365            agent_a.actor.layers[0].weights.data
1366        );
1367        assert_ne!(
1368            child.actor.layers[0].weights.data,
1369            agent_b.actor.layers[0].weights.data
1370        );
1371    }
1372
1373    #[test]
1374    fn test_agent_crossover_critic_weights_differ() {
1375        let config = default_config();
1376        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1377        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1378
1379        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1380        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1381
1382        let child: PcActorCritic = PcActorCritic::crossover(
1383            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1384        )
1385        .unwrap();
1386
1387        assert_ne!(
1388            child.critic.layers[0].weights.data,
1389            agent_a.critic.layers[0].weights.data
1390        );
1391        assert_ne!(
1392            child.critic.layers[0].weights.data,
1393            agent_b.critic.layers[0].weights.data
1394        );
1395    }
1396
1397    // ── Phase 7 Cycle 7.2: Integration — full GA workflow ───────
1398
1399    #[test]
1400    fn test_agent_crossover_child_can_infer() {
1401        let config = default_config();
1402        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1403        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1404
1405        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1406        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1407
1408        let mut child: PcActorCritic = PcActorCritic::crossover(
1409            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1410        )
1411        .unwrap();
1412
1413        let input = vec![0.5; 9];
1414        let valid = vec![0, 1, 2, 3, 4];
1415        let (action, _) = child.act(&input, &valid, SelectionMode::Training);
1416        assert!(valid.contains(&action), "Action {action} not in valid set");
1417    }
1418
1419    #[test]
1420    fn test_agent_crossover_child_can_learn() {
1421        let config = default_config();
1422        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1423        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1424
1425        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1426        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1427
1428        let mut child: PcActorCritic = PcActorCritic::crossover(
1429            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1430        )
1431        .unwrap();
1432
1433        let trajectory = make_trajectory(&mut child);
1434        let loss = child.learn(&trajectory);
1435        assert!(loss.is_finite(), "Child learn loss not finite: {loss}");
1436    }
1437
1438    #[test]
1439    fn test_agent_crossover_mismatched_batch_size_error() {
1440        let config = default_config();
1441        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1442        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1443
1444        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1445        let (ac_b, _cc_b) = build_caches_for_agent(&mut agent_b, 30); // different batch
1446        let (_, cc_b_match) = build_caches_for_agent(&mut agent_b, 50);
1447
1448        // Actor batch mismatch
1449        let result = PcActorCritic::crossover(
1450            &agent_a,
1451            &agent_b,
1452            &ac_a,
1453            &ac_b,
1454            &cc_a,
1455            &cc_b_match,
1456            0.5,
1457            config,
1458            99,
1459        );
1460        assert!(result.is_err(), "Mismatched actor batch sizes should error");
1461    }
1462
1463    // ── Fix #2: Separate critic caches in crossover ────────────
1464
1465    #[test]
1466    fn test_agent_crossover_with_separate_critic_caches() {
1467        let config = default_config();
1468        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1469        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1470
1471        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1472        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1473
1474        let child: PcActorCritic = PcActorCritic::crossover(
1475            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1476        )
1477        .unwrap();
1478
1479        assert_eq!(child.critic.layers.len(), agent_a.critic.layers.len());
1480    }
1481
1482    #[test]
1483    fn test_agent_crossover_critic_uses_own_caches() {
1484        let config = default_config();
1485        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1486        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1487
1488        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1489        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1490
1491        let child: PcActorCritic = PcActorCritic::crossover(
1492            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1493        )
1494        .unwrap();
1495
1496        assert_ne!(
1497            child.critic.layers[0].weights.data,
1498            agent_a.critic.layers[0].weights.data
1499        );
1500        assert_ne!(
1501            child.critic.layers[0].weights.data,
1502            agent_b.critic.layers[0].weights.data
1503        );
1504    }
1505
1506    #[test]
1507    fn test_agent_crossover_mismatched_critic_batch_error() {
1508        let config = default_config();
1509        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1510        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1511
1512        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1513        let (ac_b, _) = build_caches_for_agent(&mut agent_b, 50);
1514        // Build critic cache with different batch size
1515        let (_, cc_b_small) = build_caches_for_agent(&mut agent_b, 30);
1516
1517        let result = PcActorCritic::crossover(
1518            &agent_a,
1519            &agent_b,
1520            &ac_a,
1521            &ac_b,
1522            &cc_a,
1523            &cc_b_small,
1524            0.5,
1525            config,
1526            99,
1527        );
1528        assert!(
1529            result.is_err(),
1530            "Mismatched critic batch sizes should error"
1531        );
1532    }
1533
1534    // ── Phase 7 Cycle 7.3: lib.rs re-exports ────────────────────
1535
1536    #[test]
1537    fn test_activation_cache_accessible_from_crate() {
1538        // Verify ActivationCache is accessible via pc_actor_critic module
1539        let _cache: crate::pc_actor_critic::ActivationCache = ActivationCache::new(1);
1540    }
1541
1542    #[test]
1543    fn test_cca_neuron_alignment_accessible_from_crate() {
1544        // Verify cca_neuron_alignment is accessible via matrix module
1545        use crate::linalg::cpu::CpuLinAlg;
1546        use crate::linalg::LinAlg;
1547        let mat = CpuLinAlg::zeros_mat(10, 3);
1548        let _perm = crate::matrix::cca_neuron_alignment::<CpuLinAlg>(&mat, &mat).unwrap();
1549    }
1550}