Skip to main content

pc_rl_core/
pc_actor_critic.rs

1// Author: Julian Bolivar
2// Version: 1.0.0
3// Date: 2026-03-25
4
5//! Integrated PC Actor-Critic agent.
6//!
7//! Combines [`PcActor`] for action selection via predictive coding inference
8//! with [`MlpCritic`] for value estimation. Supports REINFORCE episodic
9//! learning, TD(0) continuous learning, surprise-based scheduling, and
10//! entropy regularization.
11//!
12//! Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
13
14use std::collections::VecDeque;
15
16use rand::rngs::StdRng;
17use serde::{Deserialize, Serialize};
18
19use crate::error::PcError;
20use crate::linalg::cpu::CpuLinAlg;
21use crate::linalg::LinAlg;
22use crate::mlp_critic::{MlpCritic, MlpCriticConfig};
23use crate::pc_actor::{InferResult, PcActor, PcActorConfig, SelectionMode};
24
/// Serde fallback for [`PcActorCriticConfig::gamma`]: a discount factor of `0.95`.
fn default_gamma() -> f64 {
    const GAMMA: f64 = 0.95;
    GAMMA
}
29
/// Serde fallback for [`PcActorCriticConfig::surprise_low`]: a lower threshold of `0.02`.
fn default_surprise_low() -> f64 {
    const SURPRISE_LOW: f64 = 0.02;
    SURPRISE_LOW
}
34
/// Serde fallback for [`PcActorCriticConfig::surprise_high`]: an upper threshold of `0.15`.
fn default_surprise_high() -> f64 {
    const SURPRISE_HIGH: f64 = 0.15;
    SURPRISE_HIGH
}
39
/// Serde fallback for [`PcActorCriticConfig::adaptive_surprise`]: adaptive thresholds are on.
fn default_adaptive_surprise() -> bool {
    const ADAPTIVE_SURPRISE: bool = true;
    ADAPTIVE_SURPRISE
}
44
/// Serde fallback for [`PcActorCriticConfig::surprise_buffer_size`]: `400` retained scores.
fn default_surprise_buffer_size() -> usize {
    const SURPRISE_BUFFER_SIZE: usize = 400;
    SURPRISE_BUFFER_SIZE
}
49
/// Serde fallback for [`PcActorCriticConfig::entropy_coeff`]: a coefficient of `0.01`.
fn default_entropy_coeff() -> f64 {
    const ENTROPY_COEFF: f64 = 0.01;
    ENTROPY_COEFF
}
54
55/// Configuration for the integrated PC Actor-Critic agent.
56///
57/// # Examples
58///
59/// ```
60/// use pc_rl_core::activation::Activation;
61/// use pc_rl_core::layer::LayerDef;
62/// use pc_rl_core::mlp_critic::MlpCriticConfig;
63/// use pc_rl_core::pc_actor::PcActorConfig;
64/// use pc_rl_core::pc_actor_critic::PcActorCriticConfig;
65///
66/// let config = PcActorCriticConfig {
67///     actor: PcActorConfig {
68///         input_size: 9,
69///         hidden_layers: vec![LayerDef { size: 18, activation: Activation::Tanh }],
70///         output_size: 9,
71///         output_activation: Activation::Tanh,
72///         alpha: 0.1, tol: 0.01, min_steps: 1, max_steps: 20,
73///         lr_weights: 0.01, synchronous: true, temperature: 1.0,
74///         local_lambda: 1.0,
75///         residual: false,
76///         rezero_init: 0.001,
77///     },
78///     critic: MlpCriticConfig {
79///         input_size: 27,
80///         hidden_layers: vec![LayerDef { size: 36, activation: Activation::Tanh }],
81///         output_activation: Activation::Linear,
82///         lr: 0.005,
83///     },
84///     gamma: 0.95,
85///     surprise_low: 0.02,
86///     surprise_high: 0.15,
87///     adaptive_surprise: true,
88///     surprise_buffer_size: 400,
89///     entropy_coeff: 0.01,
90/// };
91/// ```
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct PcActorCriticConfig {
94    /// Actor (PC network) configuration.
95    pub actor: PcActorConfig,
96    /// Critic (MLP value function) configuration.
97    pub critic: MlpCriticConfig,
98    /// Discount factor for computing returns. Default: 0.95.
99    #[serde(default = "default_gamma")]
100    pub gamma: f64,
101    /// Surprise threshold below which learning rate is scaled down. Default: 0.02.
102    #[serde(default = "default_surprise_low")]
103    pub surprise_low: f64,
104    /// Surprise threshold above which learning rate is scaled up. Default: 0.15.
105    #[serde(default = "default_surprise_high")]
106    pub surprise_high: f64,
107    /// Whether to adaptively recalibrate surprise thresholds. Default: true.
108    #[serde(default = "default_adaptive_surprise")]
109    pub adaptive_surprise: bool,
110    /// Maximum number of surprise scores in the adaptive buffer.
111    /// Only used when `adaptive_surprise` is true. Default: 400.
112    #[serde(default = "default_surprise_buffer_size")]
113    pub surprise_buffer_size: usize,
114    /// Entropy regularization coefficient. Default: 0.01.
115    #[serde(default = "default_entropy_coeff")]
116    pub entropy_coeff: f64,
117}
118
119/// A single step in a trajectory collected during an episode.
120///
121/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
122#[derive(Debug, Clone)]
123pub struct TrajectoryStep<L: LinAlg = CpuLinAlg> {
124    /// Board state input vector.
125    pub input: L::Vector,
126    /// Concatenated hidden layer activations from inference.
127    pub latent_concat: L::Vector,
128    /// Converged output logits from inference.
129    pub y_conv: L::Vector,
130    /// Per-layer hidden state activations from inference (for backprop).
131    pub hidden_states: Vec<L::Vector>,
132    /// Per-layer prediction errors from the PC inference loop.
133    pub prediction_errors: Vec<L::Vector>,
134    /// Per-layer tanh components for residual layers (for correct backward pass).
135    pub tanh_components: Vec<Option<L::Vector>>,
136    /// Action taken at this step.
137    pub action: usize,
138    /// Valid actions at this step (needed for masked softmax).
139    pub valid_actions: Vec<usize>,
140    /// Reward received after taking the action.
141    pub reward: f64,
142    /// Prediction error from inference.
143    pub surprise_score: f64,
144    /// Number of inference steps used.
145    pub steps_used: usize,
146}
147
148/// Cache for hidden layer activations captured during inference.
149///
150/// Used by CCA neuron alignment during crossover to compare functional
151/// neuron responses between parent networks. Activations are recorded
152/// during normal fitness evaluation at zero additional compute cost.
153///
154/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
155///
156/// # Examples
157///
158/// ```
159/// use pc_rl_core::pc_actor_critic::ActivationCache;
160/// use pc_rl_core::linalg::cpu::CpuLinAlg;
161///
162/// let cache: ActivationCache<CpuLinAlg> = ActivationCache::new(2);
163/// assert_eq!(cache.batch_size(), 0);
164/// assert_eq!(cache.num_layers(), 2);
165/// ```
166#[derive(Debug, Clone)]
167pub struct ActivationCache<L: LinAlg = CpuLinAlg> {
168    /// activations[layer_idx][batch_sample_idx] = neuron activations.
169    layers: Vec<Vec<L::Vector>>,
170}
171
172impl<L: LinAlg> ActivationCache<L> {
173    /// Creates a new empty cache with the given number of hidden layers.
174    ///
175    /// # Arguments
176    ///
177    /// * `num_layers` - Number of hidden layers to track.
178    pub fn new(num_layers: usize) -> Self {
179        Self {
180            layers: (0..num_layers).map(|_| Vec::new()).collect(),
181        }
182    }
183
184    /// Returns the number of samples recorded so far.
185    pub fn batch_size(&self) -> usize {
186        self.layers.first().map_or(0, |l| l.len())
187    }
188
189    /// Returns the number of hidden layers in the cache.
190    pub fn num_layers(&self) -> usize {
191        self.layers.len()
192    }
193
194    /// Records hidden layer activations from a single inference.
195    ///
196    /// # Arguments
197    ///
198    /// * `hidden_states` - Per-layer activations from `InferResult::hidden_states`.
199    pub fn record(&mut self, hidden_states: &[L::Vector]) {
200        for (layer, state) in self.layers.iter_mut().zip(hidden_states.iter()) {
201            layer.push(state.clone());
202        }
203    }
204
205    /// Returns the recorded activations for a given layer.
206    ///
207    /// # Arguments
208    ///
209    /// * `layer_idx` - Index of the hidden layer.
210    pub fn layer(&self, layer_idx: usize) -> &[L::Vector] {
211        &self.layers[layer_idx]
212    }
213}
214
215/// Integrated PC Actor-Critic agent.
216///
217/// Combines a predictive coding actor with an MLP critic for
218/// reinforcement learning with surprise-based scheduling.
219///
220/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
221#[derive(Debug)]
222pub struct PcActorCritic<L: LinAlg = CpuLinAlg> {
223    /// The PC actor network.
224    pub(crate) actor: PcActor<L>,
225    /// The MLP critic (value function).
226    pub(crate) critic: MlpCritic<L>,
227    /// Agent configuration.
228    pub config: PcActorCriticConfig,
229    /// Random number generator for action selection.
230    rng: StdRng,
231    /// Circular buffer of recent surprise scores for adaptive thresholds.
232    surprise_buffer: VecDeque<f64>,
233}
234
235impl<L: LinAlg> PcActorCritic<L> {
236    /// Creates a new PC Actor-Critic agent.
237    ///
238    /// # Arguments
239    ///
240    /// * `config` - Agent configuration with actor, critic, and learning parameters.
241    /// * `seed` - Random seed for reproducibility.
242    /// # Errors
243    ///
244    /// Returns `PcError::ConfigValidation` if gamma is out of `[0.0, 1.0]`,
245    /// or if actor/critic config is invalid.
246    pub fn new(config: PcActorCriticConfig, seed: u64) -> Result<Self, PcError> {
247        if !(0.0..=1.0).contains(&config.gamma) {
248            return Err(PcError::ConfigValidation(format!(
249                "gamma must be in [0.0, 1.0], got {}",
250                config.gamma
251            )));
252        }
253        if config.adaptive_surprise && config.surprise_buffer_size < 10 {
254            return Err(PcError::ConfigValidation(format!(
255                "surprise_buffer_size must be >= 10 when adaptive_surprise is enabled, got {}",
256                config.surprise_buffer_size
257            )));
258        }
259
260        use rand::SeedableRng;
261        let mut rng = StdRng::seed_from_u64(seed);
262        let actor = PcActor::<L>::new(config.actor.clone(), &mut rng)?;
263        let critic = MlpCritic::<L>::new(config.critic.clone(), &mut rng)?;
264        Ok(Self {
265            actor,
266            critic,
267            config,
268            rng,
269            surprise_buffer: VecDeque::new(),
270        })
271    }
272
273    /// Creates a child agent by crossing over two parent agents using CCA neuron alignment.
274    ///
275    /// Delegates to `PcActor::crossover` and `MlpCritic::crossover`, converting
276    /// activation caches to the matrix format expected by CCA alignment.
277    ///
278    /// # Arguments
279    ///
280    /// * `parent_a` - First parent agent (reference, typically higher fitness).
281    /// * `parent_b` - Second parent agent.
282    /// * `cache_a` - Activation cache for parent A on the reference batch.
283    /// * `cache_b` - Activation cache for parent B on the reference batch.
284    /// * `alpha` - Blending weight: 1.0 = all A, 0.0 = all B.
285    /// * `child_config` - Configuration for the child agent.
286    /// * `seed` - Random seed for the child's RNG.
287    ///
288    /// # Errors
289    ///
290    /// Returns `PcError::DimensionMismatch` if activation caches have different
291    /// batch sizes. Returns `PcError::ConfigValidation` if child config is invalid.
292    #[allow(clippy::too_many_arguments)]
293    pub fn crossover(
294        parent_a: &PcActorCritic<L>,
295        parent_b: &PcActorCritic<L>,
296        actor_cache_a: &ActivationCache<L>,
297        actor_cache_b: &ActivationCache<L>,
298        critic_cache_a: &ActivationCache<L>,
299        critic_cache_b: &ActivationCache<L>,
300        alpha: f64,
301        child_config: PcActorCriticConfig,
302        seed: u64,
303    ) -> Result<Self, PcError> {
304        // Validate actor batch sizes match
305        if actor_cache_a.batch_size() != actor_cache_b.batch_size() {
306            return Err(PcError::DimensionMismatch {
307                expected: actor_cache_a.batch_size(),
308                got: actor_cache_b.batch_size(),
309                context: "actor activation cache batch sizes must match for crossover",
310            });
311        }
312        // Validate critic batch sizes match
313        if critic_cache_a.batch_size() != critic_cache_b.batch_size() {
314            return Err(PcError::DimensionMismatch {
315                expected: critic_cache_a.batch_size(),
316                got: critic_cache_b.batch_size(),
317                context: "critic activation cache batch sizes must match for crossover",
318            });
319        }
320
321        // Convert caches to matrices [batch × neurons] for CCA
322        let actor_cache_mats_a = cache_to_matrices::<L>(actor_cache_a);
323        let actor_cache_mats_b = cache_to_matrices::<L>(actor_cache_b);
324        let critic_cache_mats_a = cache_to_matrices::<L>(critic_cache_a);
325        let critic_cache_mats_b = cache_to_matrices::<L>(critic_cache_b);
326
327        use rand::SeedableRng;
328        let mut rng = StdRng::seed_from_u64(seed);
329
330        // Crossover actor with actor-specific caches
331        let actor = PcActor::<L>::crossover(
332            &parent_a.actor,
333            &parent_b.actor,
334            &actor_cache_mats_a,
335            &actor_cache_mats_b,
336            alpha,
337            child_config.actor.clone(),
338            &mut rng,
339        )?;
340
341        // Crossover critic with critic-specific caches
342        let critic = MlpCritic::<L>::crossover(
343            &parent_a.critic,
344            &parent_b.critic,
345            &critic_cache_mats_a,
346            &critic_cache_mats_b,
347            alpha,
348            child_config.critic.clone(),
349            &mut rng,
350        )?;
351
352        Ok(Self {
353            actor,
354            critic,
355            config: child_config,
356            rng,
357            surprise_buffer: VecDeque::new(),
358        })
359    }
360
361    /// Reconstructs an agent from pre-built components (used by serializer).
362    ///
363    /// # Arguments
364    ///
365    /// * `config` - Agent configuration.
366    /// * `actor` - Pre-built PC actor with loaded weights.
367    /// * `critic` - Pre-built MLP critic with loaded weights.
368    /// * `rng` - Random number generator.
369    pub fn from_parts(
370        config: PcActorCriticConfig,
371        actor: PcActor<L>,
372        critic: MlpCritic<L>,
373        rng: StdRng,
374    ) -> Self {
375        Self {
376            actor,
377            critic,
378            config,
379            rng,
380            surprise_buffer: VecDeque::new(),
381        }
382    }
383
384    /// Runs PC inference without selecting an action or modifying RNG state.
385    ///
386    /// Use this when you only need the inference result (e.g., for TD(0)
387    /// next-state evaluation) without side effects.
388    ///
389    /// # Arguments
390    ///
391    /// * `input` - Board state vector.
392    ///
393    /// # Panics
394    ///
395    /// Panics if `input.len() != config.actor.input_size`.
396    pub fn infer(&self, input: &[f64]) -> InferResult<L> {
397        self.actor.infer(input)
398    }
399
400    /// Selects an action given the current state.
401    ///
402    /// Runs PC inference on the input, then selects an action using the
403    /// converged logits and the specified selection mode.
404    ///
405    /// # Arguments
406    ///
407    /// * `input` - Board state vector.
408    /// * `valid_actions` - Indices of legal actions.
409    /// * `mode` - Training (stochastic) or Play (deterministic).
410    ///
411    /// # Panics
412    ///
413    /// Panics if `valid_actions` is empty.
414    pub fn act(
415        &mut self,
416        input: &[f64],
417        valid_actions: &[usize],
418        mode: SelectionMode,
419    ) -> (usize, InferResult<L>) {
420        let infer_result = self.actor.infer(input);
421        let action =
422            self.actor
423                .select_action(&infer_result.y_conv, valid_actions, mode, &mut self.rng);
424        (action, infer_result)
425    }
426
427    /// Learns from a complete episode trajectory using REINFORCE with baseline.
428    ///
429    /// Empty trajectory returns 0.0 without modifying weights. Otherwise computes
430    /// discounted returns, advantages, and updates both actor and critic.
431    ///
432    /// # Arguments
433    ///
434    /// * `trajectory` - Sequence of steps from an episode.
435    ///
436    /// # Returns
437    ///
438    /// Average critic loss over the trajectory.
439    pub fn learn(&mut self, trajectory: &[TrajectoryStep<L>]) -> f64 {
440        if trajectory.is_empty() {
441            return 0.0;
442        }
443
444        let n = trajectory.len();
445
446        // Compute discounted returns backward
447        let mut returns = vec![0.0; n];
448        returns[n - 1] = trajectory[n - 1].reward;
449        for t in (0..n - 1).rev() {
450            returns[t] = trajectory[t].reward + self.config.gamma * returns[t + 1];
451        }
452
453        let mut total_loss = 0.0;
454
455        for (t, step) in trajectory.iter().enumerate() {
456            // Build critic input: concat(input, latent_concat)
457            let input_vec = L::vec_to_vec(&step.input);
458            let latent_vec = L::vec_to_vec(&step.latent_concat);
459            let mut critic_input = input_vec.clone();
460            critic_input.extend_from_slice(&latent_vec);
461
462            // V(s)
463            let value = self.critic.forward(&critic_input);
464            let advantage = returns[t] - value;
465
466            // Update critic toward discounted return
467            let loss = self.critic.update(&critic_input, returns[t]);
468            total_loss += loss;
469
470            // Policy gradient
471            let y_conv_vec = L::vec_to_vec(&step.y_conv);
472            let scaled: Vec<f64> = y_conv_vec
473                .iter()
474                .map(|&v| v / self.actor.config.temperature)
475                .collect();
476            let scaled_l = L::vec_from_slice(&scaled);
477            let pi_l = L::softmax_masked(&scaled_l, &step.valid_actions);
478            let pi = L::vec_to_vec(&pi_l);
479
480            let mut delta = vec![0.0; pi.len()];
481            for &i in &step.valid_actions {
482                delta[i] = pi[i];
483            }
484            delta[step.action] -= 1.0;
485
486            // Scale by advantage
487            for &i in &step.valid_actions {
488                delta[i] *= advantage;
489            }
490
491            // Entropy regularization
492            for &i in &step.valid_actions {
493                let log_pi = (pi[i].max(1e-10)).ln();
494                delta[i] -= self.config.entropy_coeff * (log_pi + 1.0);
495            }
496
497            // Compute surprise scale and update actor using stored hidden_states
498            let s_scale = self.surprise_scale(step.surprise_score);
499
500            let stored_infer = InferResult {
501                y_conv: step.y_conv.clone(),
502                latent_concat: step.latent_concat.clone(),
503                hidden_states: step.hidden_states.clone(),
504                prediction_errors: step.prediction_errors.clone(),
505                surprise_score: step.surprise_score,
506                steps_used: step.steps_used,
507                converged: false,
508                tanh_components: step.tanh_components.clone(),
509            };
510            self.actor
511                .update_weights(&delta, &stored_infer, &input_vec, s_scale);
512
513            // Push surprise to adaptive buffer
514            if self.config.adaptive_surprise {
515                self.push_surprise(step.surprise_score);
516            }
517        }
518
519        total_loss / n as f64
520    }
521
522    /// Single-step TD(0) continuous learning.
523    ///
524    /// # Arguments
525    ///
526    /// * `input` - Current state.
527    /// * `infer` - Inference result from `act` at current state.
528    /// * `action` - Action taken.
529    /// * `valid_actions` - Valid actions at current state.
530    /// * `reward` - Reward received.
531    /// * `next_input` - Next state.
532    /// * `next_infer` - Inference result from `act` at next state.
533    /// * `terminal` - Whether the episode ended.
534    ///
535    /// # Returns
536    ///
537    /// Critic loss for this step.
538    #[allow(clippy::too_many_arguments)]
539    pub fn learn_continuous(
540        &mut self,
541        input: &[f64],
542        infer: &InferResult<L>,
543        action: usize,
544        valid_actions: &[usize],
545        reward: f64,
546        next_input: &[f64],
547        next_infer: &InferResult<L>,
548        terminal: bool,
549    ) -> f64 {
550        // Build critic inputs
551        let latent_vec = L::vec_to_vec(&infer.latent_concat);
552        let mut critic_input = input.to_vec();
553        critic_input.extend_from_slice(&latent_vec);
554
555        let next_latent_vec = L::vec_to_vec(&next_infer.latent_concat);
556        let mut next_critic_input = next_input.to_vec();
557        next_critic_input.extend_from_slice(&next_latent_vec);
558
559        let v_s = self.critic.forward(&critic_input);
560        let v_next = if terminal {
561            0.0
562        } else {
563            self.critic.forward(&next_critic_input)
564        };
565
566        let target = reward
567            + if terminal {
568                0.0
569            } else {
570                self.config.gamma * v_next
571            };
572        let td_error = target - v_s;
573
574        // Update critic
575        let loss = self.critic.update(&critic_input, target);
576
577        // Policy gradient (same formula as learn, but scaled by td_error)
578        let y_conv_vec = L::vec_to_vec(&infer.y_conv);
579        let scaled: Vec<f64> = y_conv_vec
580            .iter()
581            .map(|&v| v / self.actor.config.temperature)
582            .collect();
583        let scaled_l = L::vec_from_slice(&scaled);
584        let pi_l = L::softmax_masked(&scaled_l, valid_actions);
585        let pi = L::vec_to_vec(&pi_l);
586
587        let mut delta = vec![0.0; pi.len()];
588        for &i in valid_actions {
589            delta[i] = pi[i];
590        }
591        delta[action] -= 1.0;
592
593        for &i in valid_actions {
594            delta[i] *= td_error;
595        }
596
597        // Entropy regularization
598        for &i in valid_actions {
599            let log_pi = (pi[i].max(1e-10)).ln();
600            delta[i] -= self.config.entropy_coeff * (log_pi + 1.0);
601        }
602
603        let s_scale = self.surprise_scale(infer.surprise_score);
604        self.actor.update_weights(&delta, infer, input, s_scale);
605
606        if self.config.adaptive_surprise {
607            self.push_surprise(infer.surprise_score);
608        }
609
610        loss
611    }
612
613    /// Computes the learning rate scale factor based on surprise score.
614    ///
615    /// - surprise <= low → 0.1
616    /// - surprise >= high → 2.0
617    /// - Between → linear interpolation from 0.1 to 2.0
618    ///
619    /// If adaptive surprise is enabled and the buffer has >= 10 entries,
620    /// thresholds are dynamically recomputed from the buffer statistics.
621    pub fn surprise_scale(&self, surprise: f64) -> f64 {
622        let (low, high) = if self.config.adaptive_surprise && self.surprise_buffer.len() >= 10 {
623            let mean = self.surprise_buffer.iter().sum::<f64>() / self.surprise_buffer.len() as f64;
624            let variance = self
625                .surprise_buffer
626                .iter()
627                .map(|&s| (s - mean) * (s - mean))
628                .sum::<f64>()
629                / self.surprise_buffer.len() as f64;
630            let std = variance.sqrt();
631            let lo = (mean - 0.5 * std).max(0.0);
632            let hi = mean + 1.5 * std;
633            (lo, hi)
634        } else {
635            (self.config.surprise_low, self.config.surprise_high)
636        };
637
638        if surprise <= low {
639            0.1
640        } else if surprise >= high {
641            2.0
642        } else {
643            // Linear interpolation
644            let t = (surprise - low) / (high - low);
645            0.1 + t * (2.0 - 0.1)
646        }
647    }
648
649    /// Pushes a surprise score into the adaptive buffer (circular).
650    fn push_surprise(&mut self, surprise: f64) {
651        if self.surprise_buffer.len() >= self.config.surprise_buffer_size {
652            self.surprise_buffer.pop_front();
653        }
654        self.surprise_buffer.push_back(surprise);
655    }
656}
657
658/// Converts an `ActivationCache` into a vector of matrices `[batch × neurons]`,
659/// one per hidden layer, suitable for CCA alignment.
660fn cache_to_matrices<L: LinAlg>(cache: &ActivationCache<L>) -> Vec<L::Matrix> {
661    let num_layers = cache.num_layers();
662    let batch_size = cache.batch_size();
663    let mut matrices = Vec::with_capacity(num_layers);
664
665    for layer_idx in 0..num_layers {
666        let samples = cache.layer(layer_idx);
667        if samples.is_empty() {
668            matrices.push(L::zeros_mat(0, 0));
669            continue;
670        }
671        let n_neurons = L::vec_len(&samples[0]);
672        let mut mat = L::zeros_mat(batch_size, n_neurons);
673        for (r, sample) in samples.iter().enumerate() {
674            for c in 0..n_neurons {
675                L::mat_set(&mut mat, r, c, L::vec_get(sample, c));
676            }
677        }
678        matrices.push(mat);
679    }
680
681    matrices
682}
683
684#[cfg(test)]
685mod tests {
686    use super::*;
687    use crate::activation::Activation;
688    use crate::layer::LayerDef;
689    use crate::pc_actor::SelectionMode;
690
691    fn default_config() -> PcActorCriticConfig {
692        PcActorCriticConfig {
693            actor: PcActorConfig {
694                input_size: 9,
695                hidden_layers: vec![LayerDef {
696                    size: 18,
697                    activation: Activation::Tanh,
698                }],
699                output_size: 9,
700                output_activation: Activation::Tanh,
701                alpha: 0.1,
702                tol: 0.01,
703                min_steps: 1,
704                max_steps: 20,
705                lr_weights: 0.01,
706                synchronous: true,
707                temperature: 1.0,
708                local_lambda: 1.0,
709                residual: false,
710                rezero_init: 0.001,
711            },
712            critic: MlpCriticConfig {
713                input_size: 27,
714                hidden_layers: vec![LayerDef {
715                    size: 36,
716                    activation: Activation::Tanh,
717                }],
718                output_activation: Activation::Linear,
719                lr: 0.005,
720            },
721            gamma: 0.95,
722            surprise_low: 0.02,
723            surprise_high: 0.15,
724            adaptive_surprise: false,
725            surprise_buffer_size: 100,
726            entropy_coeff: 0.01,
727        }
728    }
729
730    fn make_agent() -> PcActorCritic {
731        let agent: PcActorCritic = PcActorCritic::new(default_config(), 42).unwrap();
732        agent
733    }
734
735    fn make_trajectory(agent: &mut PcActorCritic) -> Vec<TrajectoryStep> {
736        let input = vec![1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0, 0.0, 0.5];
737        let valid = vec![2, 7];
738        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
739        vec![TrajectoryStep {
740            input,
741            latent_concat: infer.latent_concat,
742            y_conv: infer.y_conv,
743            hidden_states: infer.hidden_states,
744            prediction_errors: infer.prediction_errors,
745            tanh_components: infer.tanh_components,
746            action,
747            valid_actions: valid,
748            reward: 1.0,
749            surprise_score: infer.surprise_score,
750            steps_used: infer.steps_used,
751        }]
752    }
753
754    // ── learn tests ───────────────────────────────────────────────
755
756    #[test]
757    fn test_learn_empty_returns_zero_without_modifying_weights() {
758        let mut agent: PcActorCritic = make_agent();
759        let w_before = agent.actor.layers[0].weights.data.clone();
760        let cw_before = agent.critic.layers[0].weights.data.clone();
761        let loss = agent.learn(&[]);
762        assert_eq!(loss, 0.0);
763        assert_eq!(agent.actor.layers[0].weights.data, w_before);
764        assert_eq!(agent.critic.layers[0].weights.data, cw_before);
765    }
766
767    #[test]
768    fn test_learn_updates_actor_weights() {
769        let mut agent: PcActorCritic = make_agent();
770        let trajectory = make_trajectory(&mut agent);
771        let w_before = agent.actor.layers[0].weights.data.clone();
772        let _ = agent.learn(&trajectory);
773        assert_ne!(agent.actor.layers[0].weights.data, w_before);
774    }
775
776    #[test]
777    fn test_learn_updates_critic_weights() {
778        let mut agent: PcActorCritic = make_agent();
779        let trajectory = make_trajectory(&mut agent);
780        let w_before = agent.critic.layers[0].weights.data.clone();
781        let _ = agent.learn(&trajectory);
782        assert_ne!(agent.critic.layers[0].weights.data, w_before);
783    }
784
785    #[test]
786    fn test_learn_returns_finite_nonneg_loss() {
787        let mut agent: PcActorCritic = make_agent();
788        let trajectory = make_trajectory(&mut agent);
789        let loss = agent.learn(&trajectory);
790        assert!(loss.is_finite(), "Loss {loss} is not finite");
791        assert!(loss >= 0.0, "Loss {loss} is negative");
792    }
793
794    #[test]
795    fn test_learn_single_step_trajectory() {
796        let mut agent: PcActorCritic = make_agent();
797        let input = vec![0.5; 9];
798        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
799        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
800        let trajectory = vec![TrajectoryStep {
801            input,
802            latent_concat: infer.latent_concat,
803            y_conv: infer.y_conv,
804            hidden_states: infer.hidden_states,
805            prediction_errors: infer.prediction_errors,
806            tanh_components: infer.tanh_components,
807            action,
808            valid_actions: valid,
809            reward: -1.0,
810            surprise_score: infer.surprise_score,
811            steps_used: infer.steps_used,
812        }];
813        let loss = agent.learn(&trajectory);
814        assert!(loss.is_finite());
815    }
816
817    #[test]
818    fn test_learn_multi_step_uses_stored_hidden_states() {
819        // Build a 3-step trajectory to exercise multi-step learning
820        let mut agent: PcActorCritic = make_agent();
821        let inputs = [
822            vec![1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0, 0.0, 0.5],
823            vec![0.5, 0.5, -1.0, 0.0, 1.0, -0.5, 0.0, -1.0, 0.5],
824            vec![-1.0, 0.0, 1.0, -0.5, 0.5, 0.0, 1.0, -1.0, -0.5],
825        ];
826        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
827
828        let mut trajectory = Vec::new();
829        for (i, inp) in inputs.iter().enumerate() {
830            let (action, infer) = agent.act(inp, &valid, SelectionMode::Training);
831            trajectory.push(TrajectoryStep {
832                input: inp.clone(),
833                latent_concat: infer.latent_concat,
834                y_conv: infer.y_conv,
835                hidden_states: infer.hidden_states,
836                prediction_errors: infer.prediction_errors,
837                tanh_components: infer.tanh_components,
838                action,
839                valid_actions: valid.clone(),
840                reward: if i == 2 { 1.0 } else { 0.0 },
841                surprise_score: infer.surprise_score,
842                steps_used: infer.steps_used,
843            });
844        }
845
846        let loss = agent.learn(&trajectory);
847        assert!(
848            loss.is_finite(),
849            "Multi-step learn should produce finite loss"
850        );
851        assert!(loss >= 0.0);
852    }
853
854    // ── learn_continuous tests ────────────────────────────────────
855
856    #[test]
857    fn test_learn_continuous_nonterminal_uses_next_value() {
858        let mut agent: PcActorCritic = make_agent();
859        let input = vec![0.5; 9];
860        let next_input = vec![-0.5; 9];
861        let valid = vec![0, 1, 2];
862        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
863        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
864
865        // Non-terminal: should incorporate next value
866        let loss = agent.learn_continuous(
867            &input,
868            &infer,
869            action,
870            &valid,
871            0.5,
872            &next_input,
873            &next_infer,
874            false,
875        );
876        assert!(loss.is_finite());
877    }
878
879    #[test]
880    fn test_learn_continuous_terminal_uses_reward_only() {
881        let mut agent: PcActorCritic = make_agent();
882        let input = vec![0.5; 9];
883        let next_input = vec![0.0; 9];
884        let valid = vec![0, 1, 2];
885        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
886        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
887
888        // Terminal: target = reward only (no gamma * V(s'))
889        let loss = agent.learn_continuous(
890            &input,
891            &infer,
892            action,
893            &valid,
894            1.0,
895            &next_input,
896            &next_infer,
897            true,
898        );
899        assert!(loss.is_finite());
900    }
901
902    #[test]
903    fn test_learn_continuous_terminal_and_nonterminal_produce_different_updates() {
904        // Create two identical agents
905        let config = default_config();
906        let mut agent_term: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
907        let mut agent_nonterm: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
908
909        let input = vec![0.5; 9];
910        let next_input = vec![-0.5; 9];
911        let valid = vec![0, 1, 2];
912
913        // Use identical actions and inferences
914        let (action, infer) = agent_term.act(&input, &valid, SelectionMode::Training);
915        let (_, next_infer) = agent_term.act(&next_input, &valid, SelectionMode::Training);
916
917        // Clone infer for the non-terminal agent (same starting point)
918        let (action2, infer2) = agent_nonterm.act(&input, &valid, SelectionMode::Training);
919        let (_, next_infer2) = agent_nonterm.act(&next_input, &valid, SelectionMode::Training);
920
921        // Terminal update
922        let loss_term = agent_term.learn_continuous(
923            &input,
924            &infer,
925            action,
926            &valid,
927            1.0,
928            &next_input,
929            &next_infer,
930            true,
931        );
932
933        // Non-terminal update with same reward
934        let loss_nonterm = agent_nonterm.learn_continuous(
935            &input,
936            &infer2,
937            action2,
938            &valid,
939            1.0,
940            &next_input,
941            &next_infer2,
942            false,
943        );
944
945        // The losses should differ because terminal uses target=reward
946        // while non-terminal uses target=reward+gamma*V(s')
947        assert!(
948            (loss_term - loss_nonterm).abs() > 1e-15,
949            "Terminal and non-terminal should produce different losses: {loss_term} vs {loss_nonterm}"
950        );
951    }
952
953    #[test]
954    fn test_learn_continuous_updates_actor() {
955        let mut agent: PcActorCritic = make_agent();
956        let input = vec![0.5; 9];
957        let next_input = vec![-0.5; 9];
958        let valid = vec![0, 1, 2];
959        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
960        let (_, next_infer) = agent.act(&next_input, &valid, SelectionMode::Training);
961        let w_before = agent.actor.layers[0].weights.data.clone();
962        let _ = agent.learn_continuous(
963            &input,
964            &infer,
965            action,
966            &valid,
967            1.0,
968            &next_input,
969            &next_infer,
970            false,
971        );
972        assert_ne!(agent.actor.layers[0].weights.data, w_before);
973    }
974
975    // ── surprise_scale tests ─────────────────────────────────────
976
977    #[test]
978    fn test_surprise_scale_below_low() {
979        let agent: PcActorCritic = make_agent();
980        let scale = agent.surprise_scale(0.01); // below low=0.02
981        assert!((scale - 0.1).abs() < 1e-12, "Expected 0.1, got {scale}");
982    }
983
984    #[test]
985    fn test_surprise_scale_above_high() {
986        let agent: PcActorCritic = make_agent();
987        let scale = agent.surprise_scale(0.20); // above high=0.15
988        assert!((scale - 2.0).abs() < 1e-12, "Expected 2.0, got {scale}");
989    }
990
991    #[test]
992    fn test_surprise_scale_midpoint_in_range() {
993        let agent: PcActorCritic = make_agent();
994        let midpoint = (0.02 + 0.15) / 2.0;
995        let scale = agent.surprise_scale(midpoint);
996        assert!(
997            scale > 0.1 && scale < 2.0,
998            "Midpoint scale {scale} out of range"
999        );
1000    }
1001
1002    #[test]
1003    fn test_surprise_scale_monotone_increasing() {
1004        let agent: PcActorCritic = make_agent();
1005        let s1 = agent.surprise_scale(0.01);
1006        let s2 = agent.surprise_scale(0.05);
1007        let s3 = agent.surprise_scale(0.10);
1008        let s4 = agent.surprise_scale(0.20);
1009        assert!(s1 <= s2, "s1={s1} > s2={s2}");
1010        assert!(s2 <= s3, "s2={s2} > s3={s3}");
1011        assert!(s3 <= s4, "s3={s3} > s4={s4}");
1012    }
1013
1014    #[test]
1015    fn test_adaptive_surprise_recalibrates_thresholds_after_many_episodes() {
1016        let mut config = default_config();
1017        config.adaptive_surprise = true;
1018        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
1019
1020        // Fill buffer with varied surprise scores to get nonzero std
1021        for i in 0..15 {
1022            agent.push_surprise(0.1 + 0.02 * i as f64);
1023        }
1024
1025        // mean ≈ 0.24, std ≈ 0.089
1026        // adaptive low = max(0, mean - 0.5*std) ≈ 0.196
1027        // adaptive high = mean + 1.5*std ≈ 0.373
1028        // These differ from the static defaults (0.02, 0.15)
1029
1030        // Something well below adaptive low should get 0.1
1031        let scale_low = agent.surprise_scale(0.0);
1032        assert!(
1033            (scale_low - 0.1).abs() < 1e-12,
1034            "Expected 0.1 below adaptive low: got {scale_low}"
1035        );
1036
1037        // Something well above adaptive high should get 2.0
1038        let scale_high = agent.surprise_scale(1.0);
1039        assert!(
1040            (scale_high - 2.0).abs() < 1e-12,
1041            "Expected 2.0 above adaptive high: got {scale_high}"
1042        );
1043
1044        // Something at the mean should be between 0.1 and 2.0
1045        let scale_mid = agent.surprise_scale(0.24);
1046        assert!(
1047            scale_mid > 0.1 && scale_mid < 2.0,
1048            "Expected interpolated value at mean, got {scale_mid}"
1049        );
1050    }
1051
1052    #[test]
1053    fn test_entropy_regularization_prevents_policy_collapse() {
1054        // With entropy regularization, repeated learning on same trajectory
1055        // should keep the policy from collapsing to a single action
1056        let mut config = default_config();
1057        config.entropy_coeff = 0.1; // Strong entropy
1058        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
1059
1060        let input = vec![0.5; 9];
1061        let valid: Vec<usize> = (0..9).collect();
1062
1063        // Train many times on same trajectory
1064        for _ in 0..20 {
1065            let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
1066            let trajectory = vec![TrajectoryStep {
1067                input: input.clone(),
1068                latent_concat: infer.latent_concat,
1069                y_conv: infer.y_conv,
1070                hidden_states: infer.hidden_states,
1071                prediction_errors: infer.prediction_errors,
1072                tanh_components: infer.tanh_components,
1073                action,
1074                valid_actions: valid.clone(),
1075                reward: 1.0,
1076                surprise_score: infer.surprise_score,
1077                steps_used: infer.steps_used,
1078            }];
1079            let _ = agent.learn(&trajectory);
1080        }
1081
1082        // Check that policy is not collapsed (multiple actions selected over 50 trials)
1083        let mut seen = std::collections::HashSet::new();
1084        for _ in 0..50 {
1085            let (action, _) = agent.act(&input, &valid, SelectionMode::Training);
1086            seen.insert(action);
1087        }
1088        assert!(
1089            seen.len() > 1,
1090            "Entropy regularization should prevent collapse to single action, but only saw {:?}",
1091            seen
1092        );
1093    }
1094
1095    // ── act tests ─────────────────────────────────────────────────
1096
1097    #[test]
1098    fn test_act_returns_valid_action() {
1099        let mut agent: PcActorCritic = make_agent();
1100        let input = vec![0.5; 9];
1101        let valid = vec![1, 3, 5, 7];
1102        for _ in 0..20 {
1103            let (action, _) = agent.act(&input, &valid, SelectionMode::Training);
1104            assert!(valid.contains(&action), "Action {action} not in valid set");
1105        }
1106    }
1107
1108    #[test]
1109    #[should_panic]
1110    fn test_act_empty_valid_panics() {
1111        let mut agent: PcActorCritic = make_agent();
1112        let input = vec![0.5; 9];
1113        let _ = agent.act(&input, &[], SelectionMode::Training);
1114    }
1115
1116    // ── learning diagnostic test ──────────────────────────────
1117
1118    #[test]
1119    fn test_learn_improves_policy_for_rewarded_action() {
1120        // Linear output so logits are unbounded
1121        let config = PcActorCriticConfig {
1122            actor: PcActorConfig {
1123                input_size: 9,
1124                hidden_layers: vec![LayerDef {
1125                    size: 18,
1126                    activation: Activation::Tanh,
1127                }],
1128                output_size: 9,
1129                output_activation: Activation::Linear,
1130                alpha: 0.1,
1131                tol: 0.01,
1132                min_steps: 1,
1133                max_steps: 5,
1134                lr_weights: 0.01,
1135                synchronous: true,
1136                temperature: 1.0,
1137                local_lambda: 1.0,
1138                residual: false,
1139                rezero_init: 0.001,
1140            },
1141            critic: MlpCriticConfig {
1142                input_size: 27,
1143                hidden_layers: vec![LayerDef {
1144                    size: 36,
1145                    activation: Activation::Tanh,
1146                }],
1147                output_activation: Activation::Linear,
1148                lr: 0.005,
1149            },
1150            gamma: 0.99,
1151            surprise_low: 0.02,
1152            surprise_high: 0.15,
1153            adaptive_surprise: false,
1154            surprise_buffer_size: 100,
1155            entropy_coeff: 0.0, // no entropy to isolate gradient effect
1156        };
1157        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
1158
1159        let input = vec![0.0; 9];
1160        let valid = vec![0, 1, 2, 3, 4, 5, 6, 7, 8];
1161        let target_action = 4; // center
1162
1163        // Repeatedly reward action 4
1164        for _ in 0..200 {
1165            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1166            let trajectory = vec![TrajectoryStep {
1167                input: input.clone(),
1168                latent_concat: infer.latent_concat,
1169                y_conv: infer.y_conv,
1170                hidden_states: infer.hidden_states,
1171                prediction_errors: infer.prediction_errors,
1172                tanh_components: infer.tanh_components,
1173                action: target_action,
1174                valid_actions: valid.clone(),
1175                reward: 1.0,
1176                surprise_score: infer.surprise_score,
1177                steps_used: infer.steps_used,
1178            }];
1179            agent.learn(&trajectory);
1180        }
1181
1182        // After 200 episodes always rewarding action 4, it should be the
1183        // preferred action in Play mode (deterministic argmax)
1184        let (action, infer) = agent.act(&input, &valid, SelectionMode::Play);
1185
1186        // Check that action 4's logit is the highest
1187        let logit_4 = infer.y_conv[4];
1188        let max_other = valid
1189            .iter()
1190            .filter(|&&a| a != 4)
1191            .map(|&a| infer.y_conv[a])
1192            .fold(f64::NEG_INFINITY, f64::max);
1193
1194        eprintln!(
1195            "DIAGNOSTIC: action={action}, logit[4]={logit_4:.4}, max_other={max_other:.4}, \
1196             y_conv={:?}",
1197            infer
1198                .y_conv
1199                .iter()
1200                .map(|v| format!("{v:.3}"))
1201                .collect::<Vec<_>>()
1202        );
1203
1204        assert_eq!(
1205            action, target_action,
1206            "After 200 episodes rewarding action 4, agent should prefer it. Got action {action}"
1207        );
1208    }
1209
1210    // ── config validation tests ────────────────────────────────
1211
1212    #[test]
1213    fn test_new_returns_error_zero_temperature() {
1214        let mut config = default_config();
1215        config.actor.temperature = 0.0;
1216        let err = PcActorCritic::new(config, 42)
1217            .map(|_: PcActorCritic| ())
1218            .unwrap_err();
1219        assert!(format!("{err}").contains("temperature"));
1220    }
1221
1222    #[test]
1223    fn test_new_returns_error_zero_input_size() {
1224        let mut config = default_config();
1225        config.actor.input_size = 0;
1226        config.critic.input_size = 0;
1227        assert!(PcActorCritic::new(config, 42)
1228            .map(|_: PcActorCritic| ())
1229            .is_err());
1230    }
1231
1232    #[test]
1233    fn test_new_returns_error_zero_output_size() {
1234        let mut config = default_config();
1235        config.actor.output_size = 0;
1236        assert!(PcActorCritic::new(config, 42)
1237            .map(|_: PcActorCritic| ())
1238            .is_err());
1239    }
1240
1241    #[test]
1242    fn test_new_returns_error_negative_gamma() {
1243        let mut config = default_config();
1244        config.gamma = -0.1;
1245        let err = PcActorCritic::new(config, 42)
1246            .map(|_: PcActorCritic| ())
1247            .unwrap_err();
1248        assert!(format!("{err}").contains("gamma"));
1249    }
1250
1251    #[test]
1252    fn test_new_returns_error_surprise_buffer_size_zero() {
1253        let mut config = default_config();
1254        config.adaptive_surprise = true;
1255        config.surprise_buffer_size = 0;
1256        let result = PcActorCritic::new(config, 42).map(|_: PcActorCritic| ());
1257        assert!(result.is_err());
1258        let err = result.unwrap_err();
1259        assert!(
1260            format!("{err}").contains("surprise_buffer_size"),
1261            "Expected surprise_buffer_size error, got: {err}"
1262        );
1263    }
1264
1265    // ── Phase 4 Cycle 4.1: ActivationCache construction and recording ──
1266
1267    #[test]
1268    fn test_activation_cache_new_creates_empty() {
1269        let cache: ActivationCache = ActivationCache::new(3);
1270        assert_eq!(cache.batch_size(), 0);
1271    }
1272
1273    #[test]
1274    fn test_activation_cache_record_increments_batch_size() {
1275        let mut agent: PcActorCritic = make_agent();
1276        let input = vec![0.5; 9];
1277        let valid = vec![0, 1, 2];
1278        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1279
1280        let num_hidden = infer.hidden_states.len();
1281        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1282        cache.record(&infer.hidden_states);
1283        assert_eq!(cache.batch_size(), 1);
1284    }
1285
1286    #[test]
1287    fn test_activation_cache_record_multiple() {
1288        let mut agent: PcActorCritic = make_agent();
1289        let valid = vec![0, 1, 2];
1290        let init_input = vec![0.5; 9];
1291        let num_hidden = {
1292            let (_, infer) = agent.act(&init_input, &valid, SelectionMode::Training);
1293            infer.hidden_states.len()
1294        };
1295
1296        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1297        for i in 0..5 {
1298            let input = vec![i as f64 * 0.1; 9];
1299            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1300            cache.record(&infer.hidden_states);
1301        }
1302        assert_eq!(cache.batch_size(), 5);
1303    }
1304
1305    #[test]
1306    fn test_activation_cache_recorded_values_match_hidden_states() {
1307        let mut agent: PcActorCritic = make_agent();
1308        let input = vec![0.5; 9];
1309        let valid = vec![0, 1, 2];
1310        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1311
1312        let num_hidden = infer.hidden_states.len();
1313        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1314        cache.record(&infer.hidden_states);
1315
1316        // Verify recorded activations match
1317        for (layer_idx, expected) in infer.hidden_states.iter().enumerate() {
1318            let layer_data = cache.layer(layer_idx);
1319            assert_eq!(layer_data.len(), 1);
1320            assert_eq!(layer_data[0], *expected);
1321        }
1322    }
1323
1324    // ── Phase 4 Cycle 4.2: ActivationCache layer access ────────────
1325
1326    #[test]
1327    fn test_activation_cache_layer_count() {
1328        let mut agent: PcActorCritic = make_agent();
1329        let input = vec![0.5; 9];
1330        let valid = vec![0, 1, 2];
1331        let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1332
1333        let num_hidden = infer.hidden_states.len();
1334        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1335        cache.record(&infer.hidden_states);
1336
1337        assert_eq!(cache.num_layers(), num_hidden);
1338    }
1339
1340    #[test]
1341    fn test_activation_cache_layer_sample_count() {
1342        let mut agent: PcActorCritic = make_agent();
1343        let valid = vec![0, 1, 2];
1344        let init_input = vec![0.5; 9];
1345        let num_hidden = {
1346            let (_, infer) = agent.act(&init_input, &valid, SelectionMode::Training);
1347            infer.hidden_states.len()
1348        };
1349
1350        let mut cache: ActivationCache = ActivationCache::new(num_hidden);
1351        for i in 0..10 {
1352            let input = vec![i as f64 * 0.1; 9];
1353            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1354            cache.record(&infer.hidden_states);
1355        }
1356
1357        for layer_idx in 0..num_hidden {
1358            assert_eq!(
1359                cache.layer(layer_idx).len(),
1360                10,
1361                "Layer {layer_idx} should have 10 samples"
1362            );
1363        }
1364    }
1365
1366    // ── Phase 7 Cycle 7.1: PcActorCritic::crossover ────────────
1367
1368    fn build_caches_for_agent(
1369        agent: &mut PcActorCritic,
1370        batch_size: usize,
1371    ) -> (ActivationCache, ActivationCache) {
1372        let num_actor_hidden = agent.config.actor.hidden_layers.len();
1373        let num_critic_hidden = agent.config.critic.hidden_layers.len();
1374        let mut actor_cache: ActivationCache = ActivationCache::new(num_actor_hidden);
1375        let mut critic_cache: ActivationCache = ActivationCache::new(num_critic_hidden);
1376        let valid: Vec<usize> = (0..agent.config.actor.output_size).collect();
1377        for i in 0..batch_size {
1378            let input: Vec<f64> = (0..agent.config.actor.input_size)
1379                .map(|j| ((i * 9 + j) as f64 * 0.1).sin())
1380                .collect();
1381            let (_, infer) = agent.act(&input, &valid, SelectionMode::Training);
1382            actor_cache.record(&infer.hidden_states);
1383            let mut critic_input = input;
1384            critic_input.extend_from_slice(&infer.latent_concat);
1385            let (_value, critic_hidden) = agent.critic.forward_with_hidden(&critic_input);
1386            critic_cache.record(&critic_hidden);
1387        }
1388        (actor_cache, critic_cache)
1389    }
1390
1391    #[test]
1392    fn test_agent_crossover_produces_valid_agent() {
1393        let config = default_config();
1394        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1395        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1396
1397        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1398        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1399
1400        let child: PcActorCritic = PcActorCritic::crossover(
1401            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1402        )
1403        .unwrap();
1404
1405        assert_eq!(
1406            child.config.actor.hidden_layers.len(),
1407            agent_a.config.actor.hidden_layers.len()
1408        );
1409    }
1410
1411    #[test]
1412    fn test_agent_crossover_actor_weights_differ() {
1413        let config = default_config();
1414        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1415        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1416
1417        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1418        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1419
1420        let child: PcActorCritic = PcActorCritic::crossover(
1421            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1422        )
1423        .unwrap();
1424
1425        assert_ne!(
1426            child.actor.layers[0].weights.data,
1427            agent_a.actor.layers[0].weights.data
1428        );
1429        assert_ne!(
1430            child.actor.layers[0].weights.data,
1431            agent_b.actor.layers[0].weights.data
1432        );
1433    }
1434
1435    #[test]
1436    fn test_agent_crossover_critic_weights_differ() {
1437        let config = default_config();
1438        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1439        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1440
1441        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1442        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1443
1444        let child: PcActorCritic = PcActorCritic::crossover(
1445            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1446        )
1447        .unwrap();
1448
1449        assert_ne!(
1450            child.critic.layers[0].weights.data,
1451            agent_a.critic.layers[0].weights.data
1452        );
1453        assert_ne!(
1454            child.critic.layers[0].weights.data,
1455            agent_b.critic.layers[0].weights.data
1456        );
1457    }
1458
1459    // ── Phase 7 Cycle 7.2: Integration — full GA workflow ───────
1460
1461    #[test]
1462    fn test_agent_crossover_child_can_infer() {
1463        let config = default_config();
1464        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1465        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1466
1467        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1468        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1469
1470        let mut child: PcActorCritic = PcActorCritic::crossover(
1471            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1472        )
1473        .unwrap();
1474
1475        let input = vec![0.5; 9];
1476        let valid = vec![0, 1, 2, 3, 4];
1477        let (action, _) = child.act(&input, &valid, SelectionMode::Training);
1478        assert!(valid.contains(&action), "Action {action} not in valid set");
1479    }
1480
1481    #[test]
1482    fn test_agent_crossover_child_can_learn() {
1483        let config = default_config();
1484        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1485        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1486
1487        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1488        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1489
1490        let mut child: PcActorCritic = PcActorCritic::crossover(
1491            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1492        )
1493        .unwrap();
1494
1495        let trajectory = make_trajectory(&mut child);
1496        let loss = child.learn(&trajectory);
1497        assert!(loss.is_finite(), "Child learn loss not finite: {loss}");
1498    }
1499
1500    #[test]
1501    fn test_agent_crossover_mismatched_batch_size_error() {
1502        let config = default_config();
1503        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1504        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1505
1506        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1507        let (ac_b, _cc_b) = build_caches_for_agent(&mut agent_b, 30); // different batch
1508        let (_, cc_b_match) = build_caches_for_agent(&mut agent_b, 50);
1509
1510        // Actor batch mismatch
1511        let result = PcActorCritic::crossover(
1512            &agent_a,
1513            &agent_b,
1514            &ac_a,
1515            &ac_b,
1516            &cc_a,
1517            &cc_b_match,
1518            0.5,
1519            config,
1520            99,
1521        );
1522        assert!(result.is_err(), "Mismatched actor batch sizes should error");
1523    }
1524
1525    // ── Fix #2: Separate critic caches in crossover ────────────
1526
1527    #[test]
1528    fn test_agent_crossover_with_separate_critic_caches() {
1529        let config = default_config();
1530        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1531        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1532
1533        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1534        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1535
1536        let child: PcActorCritic = PcActorCritic::crossover(
1537            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1538        )
1539        .unwrap();
1540
1541        assert_eq!(child.critic.layers.len(), agent_a.critic.layers.len());
1542    }
1543
1544    #[test]
1545    fn test_agent_crossover_critic_uses_own_caches() {
1546        let config = default_config();
1547        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1548        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1549
1550        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1551        let (ac_b, cc_b) = build_caches_for_agent(&mut agent_b, 50);
1552
1553        let child: PcActorCritic = PcActorCritic::crossover(
1554            &agent_a, &agent_b, &ac_a, &ac_b, &cc_a, &cc_b, 0.5, config, 99,
1555        )
1556        .unwrap();
1557
1558        assert_ne!(
1559            child.critic.layers[0].weights.data,
1560            agent_a.critic.layers[0].weights.data
1561        );
1562        assert_ne!(
1563            child.critic.layers[0].weights.data,
1564            agent_b.critic.layers[0].weights.data
1565        );
1566    }
1567
1568    #[test]
1569    fn test_agent_crossover_mismatched_critic_batch_error() {
1570        let config = default_config();
1571        let mut agent_a: PcActorCritic = PcActorCritic::new(config.clone(), 42).unwrap();
1572        let mut agent_b: PcActorCritic = PcActorCritic::new(config.clone(), 123).unwrap();
1573
1574        let (ac_a, cc_a) = build_caches_for_agent(&mut agent_a, 50);
1575        let (ac_b, _) = build_caches_for_agent(&mut agent_b, 50);
1576        // Build critic cache with different batch size
1577        let (_, cc_b_small) = build_caches_for_agent(&mut agent_b, 30);
1578
1579        let result = PcActorCritic::crossover(
1580            &agent_a,
1581            &agent_b,
1582            &ac_a,
1583            &ac_b,
1584            &cc_a,
1585            &cc_b_small,
1586            0.5,
1587            config,
1588            99,
1589        );
1590        assert!(
1591            result.is_err(),
1592            "Mismatched critic batch sizes should error"
1593        );
1594    }
1595
1596    // ── Phase 7 Cycle 7.3: lib.rs re-exports ────────────────────
1597
1598    #[test]
1599    fn test_activation_cache_accessible_from_crate() {
1600        // Verify ActivationCache is accessible via pc_actor_critic module
1601        let _cache: crate::pc_actor_critic::ActivationCache = ActivationCache::new(1);
1602    }
1603
1604    #[test]
1605    fn test_cca_neuron_alignment_accessible_from_crate() {
1606        // Verify cca_neuron_alignment is accessible via matrix module
1607        use crate::linalg::cpu::CpuLinAlg;
1608        use crate::linalg::LinAlg;
1609        let mat = CpuLinAlg::zeros_mat(10, 3);
1610        let _perm = crate::matrix::cca_neuron_alignment::<CpuLinAlg>(&mat, &mat).unwrap();
1611    }
1612}