Skip to main content

pc_rl_core/
pc_actor.rs

1// Author: Julian Bolivar
2// Version: 1.0.0
3// Date: 2026-03-25
4
5//! Predictive Coding Actor Network.
6//!
7//! Implements an actor that uses iterative top-down/bottom-up predictive coding
8//! inference loops instead of standard feedforward passes. The prediction error
9//! (surprise score) drives learning rate modulation in the actor-critic agent.
10
11use rand::Rng;
12use serde::{Deserialize, Serialize};
13
14use crate::activation::Activation;
15use crate::error::PcError;
16use crate::layer::{Layer, LayerDef};
17use crate::linalg::cpu::CpuLinAlg;
18use crate::linalg::LinAlg;
19
20/// Configuration for the predictive coding actor network.
21///
22/// # Examples
23///
24/// ```
25/// use pc_rl_core::activation::Activation;
26/// use pc_rl_core::layer::LayerDef;
27/// use pc_rl_core::pc_actor::PcActorConfig;
28///
29/// let config = PcActorConfig {
30///     input_size: 9,
31///     hidden_layers: vec![LayerDef { size: 18, activation: Activation::Tanh }],
32///     output_size: 9,
33///     output_activation: Activation::Tanh,
34///     alpha: 0.1,
35///     tol: 0.01,
36///     min_steps: 1,
37///     max_steps: 20,
38///     lr_weights: 0.01,
39///     synchronous: true,
40///     temperature: 1.0,
41///     local_lambda: 1.0,
42///     residual: false,
43///     rezero_init: 0.001,
44/// };
45/// ```
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct PcActorConfig {
48    /// Number of input features (e.g. 9 for tic-tac-toe board).
49    pub input_size: usize,
50    /// Hidden layer topology definitions.
51    pub hidden_layers: Vec<LayerDef>,
52    /// Number of output actions.
53    pub output_size: usize,
54    /// Activation function for the output layer.
55    pub output_activation: Activation,
56    /// Inference learning rate for PC loop state updates (`h += alpha * error`).
57    /// Set to 0.0 to disable PC inference (network behaves as standard MLP).
58    /// Active regardless of `residual` setting.
59    pub alpha: f64,
60    /// Convergence threshold for RMS prediction error.
61    /// PC loop exits early when surprise < tol (after at least `min_steps`).
62    /// Active regardless of `residual` setting.
63    pub tol: f64,
64    /// Minimum PC inference steps before convergence check is allowed.
65    /// Active regardless of `residual` setting.
66    pub min_steps: usize,
67    /// Maximum PC inference steps per action.
68    /// Active regardless of `residual` setting.
69    pub max_steps: usize,
70    /// Base learning rate for weight updates.
71    pub lr_weights: f64,
72    /// If true, use synchronous snapshot mode; otherwise in-place.
73    pub synchronous: bool,
74    /// Softmax temperature for action selection.
75    pub temperature: f64,
76    /// Blend factor for hidden layer weight updates, range `[0.0, 1.0]`.
77    ///
78    /// Controls how hidden layers combine two gradient signals:
79    /// `delta = lambda * backprop_grad + (1 - lambda) * pc_prediction_error`
80    ///
81    /// - `1.0` — Pure backprop: reward signal propagated from output (default).
82    /// - `0.0` — Pure local PC: prediction errors from inference loop
83    ///   used as gradients (Millidge et al. 2022). No vanishing gradient
84    ///   but no reward signal reaches hidden layers.
85    /// - `0.0 < lambda < 1.0` — Hybrid: reward-aware backprop regularized
86    ///   by local PC consistency errors.
87    ///
88    /// The output layer always uses standard backprop regardless of this value.
89    #[serde(default = "default_local_lambda")]
90    pub local_lambda: f64,
91    /// Enable residual skip connections between same-dimension hidden layers.
92    /// When false, `rezero_init` is ignored. When true, all hidden layers
93    /// must have the same size, and skip connections with learnable ReZero
94    /// scaling are added between consecutive hidden layers (not the first,
95    /// since input_size typically differs from hidden_size).
96    #[serde(default)]
97    pub residual: bool,
98    /// Initial value for ReZero scaling factors on residual connections.
99    /// Only used when `residual = true`. Controls initial contribution of
100    /// the nonlinear component: `h[i] = rezero_init * tanh(...) + h[i-1]`.
101    ///
102    /// - `0.001` — Near-identity start (ReZero: network learns depth gradually)
103    /// - `1.0` — Standard ResNet residual (full contribution from start)
104    ///
105    /// Ignored when `residual = false`.
106    #[serde(default = "default_rezero_init")]
107    pub rezero_init: f64,
108}
109
/// Serde fallback for `PcActorConfig::rezero_init`: `0.001`, so residual
/// branches start almost switched off (near-identity network at start).
fn default_rezero_init() -> f64 {
    1e-3
}
114
/// Serde fallback for `PcActorConfig::local_lambda`: `1.0`, i.e. hidden
/// layers learn by pure backprop.
fn default_local_lambda() -> f64 {
    1.0_f64
}
119
120/// Result of the predictive coding inference loop.
121///
122/// Contains converged output logits, hidden state representations,
123/// and diagnostic information about the inference process.
124///
125/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
126#[derive(Debug, Clone)]
127pub struct InferResult<L: LinAlg = CpuLinAlg> {
128    /// Converged output logits.
129    pub y_conv: L::Vector,
130    /// All hidden states concatenated (fed to critic).
131    pub latent_concat: L::Vector,
132    /// Per-layer hidden state activations.
133    pub hidden_states: Vec<L::Vector>,
134    /// Per-layer prediction errors from the last PC inference step.
135    /// Ordered from top hidden layer to bottom (reverse layer order).
136    pub prediction_errors: Vec<L::Vector>,
137    /// RMS prediction error across layers.
138    pub surprise_score: f64,
139    /// Number of inference steps performed.
140    pub steps_used: usize,
141    /// Whether the inference loop converged within tolerance.
142    pub converged: bool,
143    /// Per-layer tanh components for residual layers.
144    /// `None` for non-skip layers, `Some(tanh_out)` for skip-eligible layers.
145    /// Needed for correct backward pass (derivative on tanh_out, not full h\[i\]).
146    pub tanh_components: Vec<Option<L::Vector>>,
147}
148
/// Strategy used by `PcActor::select_action` to turn action probabilities
/// into a single choice.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SelectionMode {
    /// Draw an action at random from the softmax distribution (exploration).
    Training,
    /// Always take the most probable valid action (argmax, no randomness).
    Play,
}
157
158/// Predictive coding actor network.
159///
160/// Uses iterative top-down/bottom-up inference loops to produce
161/// stable hidden representations and output logits.
162///
163/// Generic over a [`LinAlg`] backend `L`. Defaults to [`CpuLinAlg`].
164///
165/// # Examples
166///
167/// ```
168/// use pc_rl_core::activation::Activation;
169/// use pc_rl_core::layer::LayerDef;
170/// use pc_rl_core::pc_actor::{PcActor, PcActorConfig, SelectionMode};
171/// use rand::SeedableRng;
172/// use rand::rngs::StdRng;
173///
174/// let config = PcActorConfig {
175///     input_size: 9,
176///     hidden_layers: vec![LayerDef { size: 18, activation: Activation::Tanh }],
177///     output_size: 9,
178///     output_activation: Activation::Tanh,
179///     alpha: 0.1, tol: 0.01, min_steps: 1, max_steps: 20,
180///     lr_weights: 0.01, synchronous: true, temperature: 1.0,
181///     local_lambda: 1.0,
182///     residual: false,
183///     rezero_init: 0.001,
184/// };
185/// let mut rng = StdRng::seed_from_u64(42);
186/// let actor: PcActor = PcActor::new(config, &mut rng).unwrap();
187/// let result = actor.infer(&[0.0; 9]);
188/// assert_eq!(result.y_conv.len(), 9);
189/// ```
190#[derive(Debug)]
191pub struct PcActor<L: LinAlg = CpuLinAlg> {
192    /// Network layers: hidden_layers.len() + 1 (output layer).
193    pub(crate) layers: Vec<Layer<L>>,
194    /// Actor configuration.
195    pub config: PcActorConfig,
196    /// ReZero scaling factors for skip connections. One per skip layer (all i >= 1 when residual=true).
197    pub(crate) rezero_alpha: Vec<f64>,
198    /// Projection matrices for skip connections between layers of different sizes.
199    /// One entry per skip layer: `None` for identity (same size), `Some(Matrix)` for projection.
200    pub(crate) skip_projections: Vec<Option<L::Matrix>>,
201}
202
203impl<L: LinAlg> PcActor<L> {
204    /// Creates a new PC actor with Xavier-initialized layers.
205    ///
206    /// # Arguments
207    ///
208    /// * `config` - Actor configuration specifying topology and hyperparameters.
209    /// * `rng` - Random number generator for weight initialization.
210    ///
211    /// # Errors
212    ///
213    /// Returns `PcError::ConfigValidation` if `input_size`, `output_size`,
214    /// or `temperature` are invalid.
215    pub fn new(config: PcActorConfig, rng: &mut impl Rng) -> Result<Self, PcError> {
216        if config.input_size == 0 {
217            return Err(PcError::ConfigValidation("input_size must be > 0".into()));
218        }
219        if config.output_size == 0 {
220            return Err(PcError::ConfigValidation("output_size must be > 0".into()));
221        }
222        if config.temperature <= 0.0 {
223            return Err(PcError::ConfigValidation(format!(
224                "temperature must be positive, got {}",
225                config.temperature
226            )));
227        }
228        if !(0.0..=1.0).contains(&config.local_lambda) {
229            return Err(PcError::ConfigValidation(format!(
230                "local_lambda must be in [0.0, 1.0], got {}",
231                config.local_lambda
232            )));
233        }
234        if config.rezero_init < 0.0 {
235            return Err(PcError::ConfigValidation(format!(
236                "rezero_init must be >= 0, got {}",
237                config.rezero_init
238            )));
239        }
240        let mut layers: Vec<Layer<L>> = Vec::new();
241        let mut prev_size = config.input_size;
242
243        for def in &config.hidden_layers {
244            layers.push(Layer::<L>::new(prev_size, def.size, def.activation, rng));
245            prev_size = def.size;
246        }
247
248        // Output layer
249        layers.push(Layer::<L>::new(
250            prev_size,
251            config.output_size,
252            config.output_activation,
253            rng,
254        ));
255
256        // Compute rezero_alpha and skip_projections: one per skip layer (all i >= 1)
257        let (rezero_alpha, skip_projections) = if config.residual {
258            let mut alphas = Vec::new();
259            let mut projs = Vec::new();
260            for i in 1..config.hidden_layers.len() {
261                alphas.push(config.rezero_init);
262                if config.hidden_layers[i].size != config.hidden_layers[i - 1].size {
263                    projs.push(Some(L::xavier_mat(
264                        config.hidden_layers[i].size,
265                        config.hidden_layers[i - 1].size,
266                        rng,
267                    )));
268                } else {
269                    projs.push(None);
270                }
271            }
272            (alphas, projs)
273        } else {
274            (Vec::new(), Vec::new())
275        };
276
277        Ok(Self {
278            layers,
279            config,
280            rezero_alpha,
281            skip_projections,
282        })
283    }
284
285    /// Returns the total size of the latent concatenation (sum of hidden layer sizes).
286    pub fn latent_size(&self) -> usize {
287        self.config.hidden_layers.iter().map(|def| def.size).sum()
288    }
289
290    /// Runs the predictive coding inference loop on the given input.
291    ///
292    /// This method is `&self` — it never modifies weights.
293    ///
294    /// # Arguments
295    ///
296    /// * `input` - Input vector of length `input_size`.
297    ///
298    /// # Panics
299    ///
300    /// Panics if `input.len() != config.input_size`.
301    /// Returns whether hidden layer `i` has a skip connection (identity or projection).
302    fn is_skip_layer(&self, i: usize) -> bool {
303        self.config.residual && i >= 1
304    }
305
306    /// Returns the rezero_alpha/skip_projections index for hidden layer `i`.
307    fn skip_alpha_index(&self, i: usize) -> Option<usize> {
308        if !self.is_skip_layer(i) {
309            return None;
310        }
311        Some(i - 1)
312    }
313
314    pub fn infer(&self, input: &[f64]) -> InferResult<L> {
315        assert_eq!(
316            input.len(),
317            self.config.input_size,
318            "input size mismatch: got {}, expected {}",
319            input.len(),
320            self.config.input_size
321        );
322
323        let input_vec = L::vec_from_slice(input);
324        let n_hidden = self.config.hidden_layers.len();
325
326        // Forward pass to initialize hidden states and output
327        let mut hidden_states: Vec<L::Vector> = Vec::with_capacity(n_hidden);
328        let mut tanh_components: Vec<Option<L::Vector>> = Vec::with_capacity(n_hidden);
329        let mut prev = input_vec.clone();
330        for (i, layer) in self.layers[..n_hidden].iter().enumerate() {
331            let tanh_out = layer.forward(&prev);
332            if let Some(alpha_idx) = self.skip_alpha_index(i) {
333                let alpha = self.rezero_alpha[alpha_idx];
334                let scaled = L::vec_scale(&tanh_out, alpha);
335                let skip_path = if let Some(ref proj) = self.skip_projections[alpha_idx] {
336                    L::mat_vec_mul(proj, &prev)
337                } else {
338                    prev.clone()
339                };
340                prev = L::vec_add(&skip_path, &scaled);
341                tanh_components.push(Some(tanh_out));
342            } else {
343                prev = tanh_out;
344                tanh_components.push(None);
345            }
346            hidden_states.push(prev.clone());
347        }
348        // Output from last hidden (or input if no hidden)
349        let last_input = if n_hidden > 0 {
350            &hidden_states[n_hidden - 1]
351        } else {
352            &input_vec
353        };
354        let mut y = self.layers[n_hidden].forward(last_input);
355
356        // PC inference loop
357        let mut steps_used = 0;
358        let mut converged = false;
359        let mut surprise_score = 0.0;
360        let mut last_errors: Vec<L::Vector> = Vec::new();
361
362        for step in 0..self.config.max_steps {
363            steps_used = step + 1;
364
365            if self.config.synchronous {
366                // Snapshot mode: freeze all states
367                let snapshot: Vec<L::Vector> = hidden_states.clone();
368                let tanh_snap: Vec<Option<L::Vector>> = tanh_components.clone();
369
370                let mut error_vecs: Vec<L::Vector> = Vec::new();
371
372                for i in (0..n_hidden).rev() {
373                    // For top-down prediction, use tanh_component of layer above
374                    // (not the full residual sum) when it is a skip layer.
375                    let state_above = if i == n_hidden - 1 {
376                        &y
377                    } else if let Some(ref tc) = tanh_snap[i + 1] {
378                        tc
379                    } else {
380                        &snapshot[i + 1]
381                    };
382
383                    // Top-down prediction targets tanh_component for skip layers
384                    let target = if let Some(ref tc) = tanh_snap[i] {
385                        tc
386                    } else {
387                        &snapshot[i]
388                    };
389
390                    let prediction = self.layers[i + 1]
391                        .transpose_forward(state_above, self.config.hidden_layers[i].activation);
392
393                    let error = L::vec_sub(&prediction, target);
394                    error_vecs.push(error.clone());
395
396                    // Update tanh_component or hidden_state
397                    let updated_target =
398                        L::vec_add(target, &L::vec_scale(&error, self.config.alpha));
399                    if let Some(alpha_idx) = self.skip_alpha_index(i) {
400                        tanh_components[i] = Some(updated_target.clone());
401                        let alpha = self.rezero_alpha[alpha_idx];
402                        let prev_h = if i > 0 {
403                            &hidden_states[i - 1]
404                        } else {
405                            &input_vec
406                        };
407                        let skip_path = if let Some(ref proj) = self.skip_projections[alpha_idx] {
408                            L::mat_vec_mul(proj, prev_h)
409                        } else {
410                            prev_h.clone()
411                        };
412                        hidden_states[i] =
413                            L::vec_add(&skip_path, &L::vec_scale(&updated_target, alpha));
414                    } else {
415                        hidden_states[i] = updated_target;
416                    }
417                }
418
419                let top_hidden = if n_hidden > 0 {
420                    &hidden_states[n_hidden - 1]
421                } else {
422                    &input_vec
423                };
424                y = self.layers[n_hidden].forward(top_hidden);
425
426                let refs: Vec<&L::Vector> = error_vecs.iter().collect();
427                surprise_score = L::rms_error(&refs);
428                last_errors = error_vecs;
429            } else {
430                // In-place mode: updates immediately visible
431                let mut error_vecs: Vec<L::Vector> = Vec::new();
432
433                for i in (0..n_hidden).rev() {
434                    // For top-down prediction, use tanh_component of layer above
435                    // (not the full residual sum) when it is a skip layer.
436                    let state_above = if i == n_hidden - 1 {
437                        &y
438                    } else if let Some(ref tc) = tanh_components[i + 1] {
439                        tc
440                    } else {
441                        &hidden_states[i + 1]
442                    };
443
444                    let target = if let Some(ref tc) = tanh_components[i] {
445                        tc.clone()
446                    } else {
447                        hidden_states[i].clone()
448                    };
449
450                    let prediction = self.layers[i + 1]
451                        .transpose_forward(state_above, self.config.hidden_layers[i].activation);
452
453                    let error = L::vec_sub(&prediction, &target);
454                    error_vecs.push(error.clone());
455
456                    let updated_target =
457                        L::vec_add(&target, &L::vec_scale(&error, self.config.alpha));
458                    if let Some(alpha_idx) = self.skip_alpha_index(i) {
459                        tanh_components[i] = Some(updated_target.clone());
460                        let alpha = self.rezero_alpha[alpha_idx];
461                        let prev_h = if i > 0 {
462                            &hidden_states[i - 1]
463                        } else {
464                            &input_vec
465                        };
466                        let skip_path = if let Some(ref proj) = self.skip_projections[alpha_idx] {
467                            L::mat_vec_mul(proj, prev_h)
468                        } else {
469                            prev_h.clone()
470                        };
471                        hidden_states[i] =
472                            L::vec_add(&skip_path, &L::vec_scale(&updated_target, alpha));
473                    } else {
474                        hidden_states[i] = updated_target;
475                    }
476                }
477
478                let top_hidden = if n_hidden > 0 {
479                    &hidden_states[n_hidden - 1]
480                } else {
481                    &input_vec
482                };
483                y = self.layers[n_hidden].forward(top_hidden);
484
485                let refs: Vec<&L::Vector> = error_vecs.iter().collect();
486                surprise_score = L::rms_error(&refs);
487                last_errors = error_vecs;
488            }
489
490            // Convergence check (alpha must be > 0 for meaningful convergence)
491            if self.config.alpha > 0.0
492                && step + 1 >= self.config.min_steps
493                && surprise_score < self.config.tol
494            {
495                converged = true;
496                break;
497            }
498        }
499
500        // Build latent_concat (uses vec_to_vec for GPU compatibility)
501        let mut latent_raw: Vec<f64> = Vec::new();
502        for h in &hidden_states {
503            latent_raw.extend_from_slice(&L::vec_to_vec(h));
504        }
505        let latent_concat = L::vec_from_slice(&latent_raw);
506
507        InferResult {
508            y_conv: y,
509            latent_concat,
510            hidden_states,
511            prediction_errors: last_errors,
512            surprise_score,
513            steps_used,
514            converged,
515            tanh_components,
516        }
517    }
518
519    /// Selects an action given converged output logits and valid actions.
520    ///
521    /// # Arguments
522    ///
523    /// * `y_conv` - Output logits from inference.
524    /// * `valid_actions` - Indices of valid actions.
525    /// * `mode` - Training (stochastic) or Play (deterministic).
526    /// * `rng` - Random number generator (used only in Training mode).
527    ///
528    /// # Panics
529    ///
530    /// Panics if `valid_actions` is empty.
531    pub fn select_action(
532        &self,
533        y_conv: &L::Vector,
534        valid_actions: &[usize],
535        mode: SelectionMode,
536        rng: &mut impl Rng,
537    ) -> usize {
538        assert!(!valid_actions.is_empty(), "valid_actions must not be empty");
539
540        // Scale logits by temperature
541        let scaled = L::vec_scale(y_conv, 1.0 / self.config.temperature);
542
543        let probs = L::softmax_masked(&scaled, valid_actions);
544
545        match mode {
546            SelectionMode::Play => L::argmax_masked(&probs, valid_actions),
547            SelectionMode::Training => L::sample_from_probs(&probs, valid_actions, rng),
548        }
549    }
550
551    /// Updates network weights using a blend of backprop and local PC error.
552    ///
553    /// The `local_lambda` config controls the blend: 1.0 = pure backprop,
554    /// 0.0 = pure local PC learning (Millidge et al. 2022), intermediate = hybrid.
555    ///
556    /// # Arguments
557    ///
558    /// * `output_delta` - Error signal at the output layer.
559    /// * `infer_result` - Result from the most recent inference.
560    /// * `input` - Original input that was fed to `infer`.
561    /// * `surprise_scale` - Multiplier on learning rate based on surprise.
562    ///
563    /// # Panics
564    ///
565    /// Panics if `input.len() != config.input_size`.
566    pub fn update_weights(
567        &mut self,
568        output_delta: &[f64],
569        infer_result: &InferResult<L>,
570        input: &[f64],
571        surprise_scale: f64,
572    ) {
573        assert_eq!(
574            input.len(),
575            self.config.input_size,
576            "input size mismatch: got {}, expected {}",
577            input.len(),
578            self.config.input_size
579        );
580
581        self.update_weights_hybrid(
582            output_delta,
583            infer_result,
584            input,
585            surprise_scale,
586            self.config.local_lambda,
587        );
588    }
589
590    /// Hybrid weight update blending backprop and local PC error signals.
591    ///
592    /// For hidden layers, the effective delta is:
593    /// `delta = lambda * backprop_delta + (1 - lambda) * pc_error`
594    ///
595    /// * `lambda = 1.0` → pure backprop (standard mode).
596    /// * `lambda = 0.0` → pure local PC learning (Millidge et al. 2022).
597    /// * `0 < lambda < 1` → hybrid blend.
598    ///
599    /// The output layer always uses standard backprop from `output_delta`.
600    fn update_weights_hybrid(
601        &mut self,
602        output_delta: &[f64],
603        infer_result: &InferResult<L>,
604        input: &[f64],
605        surprise_scale: f64,
606        lambda: f64,
607    ) {
608        let input_vec = L::vec_from_slice(input);
609        let output_delta_vec = L::vec_from_slice(output_delta);
610        let n_hidden = self.config.hidden_layers.len();
611        let n_layers = self.layers.len();
612
613        // Output layer: always standard backward
614        let output_input = if n_hidden > 0 {
615            &infer_result.hidden_states[n_hidden - 1]
616        } else {
617            &input_vec
618        };
619        let output_output = &infer_result.y_conv;
620        let mut bp_delta = self.layers[n_layers - 1].backward(
621            output_input,
622            output_output,
623            &output_delta_vec,
624            self.config.lr_weights,
625            surprise_scale,
626        );
627
628        // Hidden layers (from top to bottom)
629        for i in (0..n_hidden).rev() {
630            let layer_input = if i > 0 {
631                &infer_result.hidden_states[i - 1]
632            } else {
633                &input_vec
634            };
635
636            // Blend backprop delta with local PC error
637            let effective_delta = if (lambda - 1.0).abs() < f64::EPSILON {
638                bp_delta.clone()
639            } else if lambda.abs() < f64::EPSILON {
640                let error_idx = n_hidden - 1 - i;
641                infer_result.prediction_errors[error_idx].clone()
642            } else {
643                let error_idx = n_hidden - 1 - i;
644                let pc_error = &infer_result.prediction_errors[error_idx];
645                let bp_scaled = L::vec_scale(&bp_delta, lambda);
646                let pc_scaled = L::vec_scale(pc_error, 1.0 - lambda);
647                L::vec_add(&bp_scaled, &pc_scaled)
648            };
649
650            if let Some(alpha_idx) = self.skip_alpha_index(i) {
651                // Skip-eligible layer: use tanh_out for derivative, scale by alpha,
652                // add identity path to propagated gradient, update alpha.
653                let tanh_out = infer_result.tanh_components[i].as_ref().unwrap();
654                let alpha = self.rezero_alpha[alpha_idx];
655                let effective_lr = self.config.lr_weights * surprise_scale;
656
657                // Scale delta by rezero_alpha for the nonlinear path
658                let scaled_delta = L::vec_scale(&effective_delta, alpha);
659
660                // Backward through the layer using tanh_out (not hidden_states[i])
661                let propagated = self.layers[i].backward(
662                    layer_input,
663                    tanh_out,
664                    &scaled_delta,
665                    self.config.lr_weights,
666                    surprise_scale,
667                );
668
669                // Update rezero_alpha: dL/d(alpha) = delta · tanh_out
670                let grad_alpha: f64 = L::vec_dot(&effective_delta, tanh_out);
671                self.rezero_alpha[alpha_idx] -= effective_lr * grad_alpha;
672
673                // Propagated delta = nonlinear path + skip path (identity or projection)
674                if let Some(ref mut proj) = self.skip_projections[alpha_idx] {
675                    // Projection path: W_proj^T × delta
676                    let proj_t = L::mat_transpose(proj);
677                    let skip_delta = L::mat_vec_mul(&proj_t, &effective_delta);
678                    // Update projection: W_proj -= lr × outer(delta, layer_input)
679                    let dw_proj = L::outer_product(&effective_delta, layer_input);
680                    L::mat_scale_add(proj, &dw_proj, -effective_lr);
681                    bp_delta = L::vec_add(&propagated, &skip_delta);
682                } else {
683                    // Identity path: + delta
684                    bp_delta = L::vec_add(&propagated, &effective_delta);
685                }
686            } else {
687                // Standard layer: use hidden_states[i] as output
688                let layer_output = &infer_result.hidden_states[i];
689                bp_delta = self.layers[i].backward(
690                    layer_input,
691                    layer_output,
692                    &effective_delta,
693                    self.config.lr_weights,
694                    surprise_scale,
695                );
696            }
697        }
698    }
699
700    /// Extracts a serializable snapshot of current weights.
701    ///
702    /// Converts generic layers and skip projections to CPU-backed types.
703    pub fn to_weights(&self) -> crate::serializer::PcActorWeights {
704        let cpu_layers: Vec<Layer<CpuLinAlg>> = self
705            .layers
706            .iter()
707            .map(|layer| {
708                let rows = L::mat_rows(&layer.weights);
709                let cols = L::mat_cols(&layer.weights);
710                let mut cpu_weights = crate::matrix::Matrix::zeros(rows, cols);
711                for r in 0..rows {
712                    for c in 0..cols {
713                        cpu_weights.set(r, c, L::mat_get(&layer.weights, r, c));
714                    }
715                }
716                let bias_data = L::vec_to_vec(&layer.bias);
717                Layer {
718                    weights: cpu_weights,
719                    bias: bias_data,
720                    activation: layer.activation,
721                }
722            })
723            .collect();
724        let cpu_projs: Vec<Option<crate::matrix::Matrix>> = self
725            .skip_projections
726            .iter()
727            .map(|opt| {
728                opt.as_ref().map(|m| {
729                    let rows = L::mat_rows(m);
730                    let cols = L::mat_cols(m);
731                    let mut cpu_m = crate::matrix::Matrix::zeros(rows, cols);
732                    for r in 0..rows {
733                        for c in 0..cols {
734                            cpu_m.set(r, c, L::mat_get(m, r, c));
735                        }
736                    }
737                    cpu_m
738                })
739            })
740            .collect();
741        crate::serializer::PcActorWeights {
742            layers: cpu_layers,
743            rezero_alpha: self.rezero_alpha.clone(),
744            skip_projections: cpu_projs,
745        }
746    }
747
748    /// Restores an actor from saved weights without requiring an RNG.
749    ///
750    /// Converts CPU-backed weight snapshots to the target backend `L`.
751    pub fn from_weights(config: PcActorConfig, weights: crate::serializer::PcActorWeights) -> Self {
752        let layers: Vec<Layer<L>> = weights
753            .layers
754            .into_iter()
755            .map(|cpu_layer| {
756                let rows = cpu_layer.weights.rows;
757                let cols = cpu_layer.weights.cols;
758                let mut mat = L::zeros_mat(rows, cols);
759                for r in 0..rows {
760                    for c in 0..cols {
761                        L::mat_set(&mut mat, r, c, cpu_layer.weights.get(r, c));
762                    }
763                }
764                let bias = L::vec_from_slice(&cpu_layer.bias);
765                Layer {
766                    weights: mat,
767                    bias,
768                    activation: cpu_layer.activation,
769                }
770            })
771            .collect();
772        let skip_projections: Vec<Option<L::Matrix>> = weights
773            .skip_projections
774            .into_iter()
775            .map(|opt| {
776                opt.map(|cpu_m| {
777                    let rows = cpu_m.rows;
778                    let cols = cpu_m.cols;
779                    let mut mat = L::zeros_mat(rows, cols);
780                    for r in 0..rows {
781                        for c in 0..cols {
782                            L::mat_set(&mut mat, r, c, cpu_m.get(r, c));
783                        }
784                    }
785                    mat
786                })
787            })
788            .collect();
789        Self {
790            layers,
791            config,
792            rezero_alpha: weights.rezero_alpha,
793            skip_projections,
794        }
795    }
796}
797
798#[cfg(test)]
799mod tests {
800    use super::*;
801    use crate::activation::Activation;
802    use crate::layer::LayerDef;
803    use crate::matrix::WEIGHT_CLIP;
804    use rand::rngs::StdRng;
805    use rand::SeedableRng;
806
807    fn make_rng() -> StdRng {
808        StdRng::seed_from_u64(42)
809    }
810
811    fn default_config() -> PcActorConfig {
812        PcActorConfig {
813            input_size: 9,
814            hidden_layers: vec![LayerDef {
815                size: 18,
816                activation: Activation::Tanh,
817            }],
818            output_size: 9,
819            output_activation: Activation::Tanh,
820            alpha: 0.1,
821            tol: 0.01,
822            min_steps: 1,
823            max_steps: 20,
824            lr_weights: 0.01,
825            synchronous: true,
826            temperature: 1.0,
827            local_lambda: 1.0,
828            residual: false,
829            rezero_init: 0.001,
830        }
831    }
832
833    fn two_hidden_config() -> PcActorConfig {
834        PcActorConfig {
835            hidden_layers: vec![
836                LayerDef {
837                    size: 18,
838                    activation: Activation::Tanh,
839                },
840                LayerDef {
841                    size: 12,
842                    activation: Activation::Tanh,
843                },
844            ],
845            ..default_config()
846        }
847    }
848
849    // ── Inference Tests ──────────────────────────────────────────────
850
#[test]
fn test_infer_converges_on_zero_board() {
    // Smoke test: inference on an all-zero board finishes and yields finite outputs.
    let actor: PcActor = PcActor::new(default_config(), &mut make_rng()).unwrap();
    let outcome = actor.infer(&[0.0; 9]);
    assert!(outcome.y_conv.iter().all(|y| y.is_finite()));
}
861
#[test]
fn test_infer_steps_used_at_least_min_steps() {
    // The loop may not stop before `min_steps`, even if it converges earlier.
    let mut config = default_config();
    config.min_steps = 3;
    let actor: PcActor = PcActor::new(config, &mut make_rng()).unwrap();
    let steps = actor.infer(&[0.0; 9]).steps_used;
    assert!(steps >= 3);
}
873
#[test]
fn test_infer_alpha_zero_does_not_converge() {
    // alpha = 0 disables the PC state updates, so the loop is expected to run the
    // full `max_steps` budget (20) without reporting convergence.
    let frozen = PcActorConfig {
        alpha: 0.0,
        ..default_config()
    };
    let actor: PcActor = PcActor::new(frozen, &mut make_rng()).unwrap();
    let outcome = actor.infer(&[0.0; 9]);
    assert!(!outcome.converged);
    assert_eq!(outcome.steps_used, 20);
}
886
#[test]
fn test_infer_does_not_modify_weights() {
    // `infer` takes `&self`; this pins that no weight is mutated behind the scenes.
    let actor: PcActor = PcActor::new(default_config(), &mut make_rng()).unwrap();
    let snapshot: Vec<Vec<f64>> = actor
        .layers
        .iter()
        .map(|layer| layer.weights.data.clone())
        .collect();
    let _ = actor.infer(&[0.0; 9]);
    for (idx, layer) in actor.layers.iter().enumerate() {
        assert_eq!(layer.weights.data, snapshot[idx]);
    }
}
901
#[test]
fn test_infer_latent_size_single_hidden() {
    let actor: PcActor = PcActor::new(default_config(), &mut make_rng()).unwrap();
    let latent = actor.infer(&[0.0; 9]).latent_concat;
    // One hidden layer of 18 units.
    assert_eq!(latent.len(), 18);
}
909
#[test]
fn test_infer_latent_size_two_hidden() {
    let actor: PcActor = PcActor::new(two_hidden_config(), &mut make_rng()).unwrap();
    let latent = actor.infer(&[0.0; 9]).latent_concat;
    // Hidden widths 18 and 12 are concatenated.
    assert_eq!(latent.len(), 18 + 12);
}
917
#[test]
fn test_infer_latent_size_matches_latent_size_method() {
    let actor: PcActor = PcActor::new(two_hidden_config(), &mut make_rng()).unwrap();
    let reported = actor.latent_size();
    let produced = actor.infer(&[0.0; 9]).latent_concat.len();
    assert_eq!(produced, reported);
}
925
#[test]
fn test_infer_y_conv_length_equals_output_size() {
    let actor: PcActor = PcActor::new(default_config(), &mut make_rng()).unwrap();
    let outputs = actor.infer(&[0.0; 9]).y_conv;
    assert_eq!(outputs.len(), 9);
}
933
#[test]
fn test_infer_hidden_states_count_matches_hidden_layers() {
    let actor: PcActor = PcActor::new(two_hidden_config(), &mut make_rng()).unwrap();
    let states = actor.infer(&[0.0; 9]).hidden_states;
    assert_eq!(states.len(), 2);
}
941
#[test]
fn test_infer_all_outputs_finite() {
    // Mixed-sign board: every numeric output of `infer` must stay finite.
    let actor: PcActor = PcActor::new(default_config(), &mut make_rng()).unwrap();
    let outcome = actor.infer(&[1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5]);
    assert!(outcome.y_conv.iter().all(|v| v.is_finite()));
    assert!(outcome.latent_concat.iter().all(|v| v.is_finite()));
    assert!(outcome.surprise_score.is_finite());
}
955
#[test]
fn test_infer_surprise_score_nonnegative() {
    let actor: PcActor = PcActor::new(default_config(), &mut make_rng()).unwrap();
    let surprise = actor.infer(&[0.0; 9]).surprise_score;
    assert!(surprise >= 0.0);
}
963
#[test]
fn test_infer_synchronous_and_inplace_both_converge() {
    // Both update schedules must terminate properly within the step budget of
    // `default_config` (min_steps = 1, max_steps = 20). A run either reports
    // convergence or stops exactly at `max_steps`. The previous version only
    // asserted `steps_used > 0`, which can never fail and so proved nothing.
    let mut rng = make_rng();
    let sync_actor: PcActor = PcActor::new(default_config(), &mut rng).unwrap();
    let mut rng2 = make_rng();
    let inplace_config = PcActorConfig {
        synchronous: false,
        ..default_config()
    };
    let inplace_actor: PcActor = PcActor::new(inplace_config, &mut rng2).unwrap();
    let sync_result = sync_actor.infer(&[0.0; 9]);
    let inplace_result = inplace_actor.infer(&[0.0; 9]);

    for (label, result) in [("synchronous", &sync_result), ("in-place", &inplace_result)] {
        assert!(
            result.steps_used > 0,
            "{label} run used no inference steps"
        );
        assert!(
            result.steps_used <= 20,
            "{label} run exceeded max_steps: {}",
            result.steps_used
        );
        assert!(
            result.converged || result.steps_used == 20,
            "{label} run stopped early ({} steps) without converging",
            result.steps_used
        );
    }
}
980
#[test]
fn test_infer_synchronous_produces_different_result_than_inplace() {
    // Two hidden layers, a large alpha and a tolerance so tight that the loop
    // effectively always runs to max_steps = 3. That is enough for the synchronous
    // and in-place update orders to diverge.
    let tanh = |size| LayerDef {
        size,
        activation: Activation::Tanh,
    };
    let config = PcActorConfig {
        hidden_layers: vec![tanh(18), tanh(12)],
        alpha: 0.3,
        tol: 1e-15,
        min_steps: 1,
        max_steps: 3,
        ..default_config()
    };
    let inplace_config = PcActorConfig {
        synchronous: false,
        ..config.clone()
    };
    let sync_actor: PcActor = PcActor::new(config, &mut make_rng()).unwrap();
    let inplace_actor: PcActor = PcActor::new(inplace_config, &mut make_rng()).unwrap();

    let board = [1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5];
    let sync_latent = sync_actor.infer(&board).latent_concat;
    let inplace_latent = inplace_actor.infer(&board).latent_concat;
    let differs = sync_latent
        .iter()
        .zip(inplace_latent.iter())
        .any(|(a, b)| (a - b).abs() > 1e-12);
    assert!(
        differs,
        "Synchronous and in-place should produce different results"
    );
}
1022
#[test]
#[should_panic(expected = "input size")]
fn test_infer_panics_wrong_input_length() {
    // The actor expects 9 inputs; 5 must trip the input-size check.
    let actor: PcActor = PcActor::new(default_config(), &mut make_rng()).unwrap();
    let short_input = [0.0; 5];
    let _ = actor.infer(&short_input);
}
1030
1031    // ── Action Selection Tests ───────────────────────────────────────
1032
#[test]
fn test_select_action_training_always_in_valid() {
    // Stochastic sampling must never pick a masked-out action.
    let mut rng = make_rng();
    let actor: PcActor = PcActor::new(default_config(), &mut rng).unwrap();
    let logits = vec![0.1, -0.2, 0.5, -0.1, 0.3, 0.0, -0.3, 0.2, 0.4];
    let valid: Vec<usize> = (0..9).step_by(2).collect();
    let picks: Vec<usize> = (0..20)
        .map(|_| actor.select_action(&logits, &valid, SelectionMode::Training, &mut rng))
        .collect();
    assert!(picks.iter().all(|pick| valid.contains(pick)));
}
1044
#[test]
fn test_select_action_play_mode_deterministic() {
    // Play mode must not depend on the RNG: two very different seeds agree.
    let actor: PcActor = PcActor::new(default_config(), &mut make_rng()).unwrap();
    let logits = vec![0.1, -0.2, 0.5, -0.1, 0.3, 0.0, -0.3, 0.2, 0.4];
    let valid = vec![0, 2, 4, 6, 8];
    let pick_with_seed = |seed: u64| {
        let mut rng = StdRng::seed_from_u64(seed);
        actor.select_action(&logits, &valid, SelectionMode::Play, &mut rng)
    };
    assert_eq!(
        pick_with_seed(1),
        pick_with_seed(99),
        "Play mode should be deterministic"
    );
}
1057
#[test]
fn test_select_action_temperature_gt_one_more_uniform() {
    // One dominant logit (10.0) with temperature 5.0: sampling should still reach
    // more than one action in 100 draws.
    let hot = PcActorConfig {
        temperature: 5.0,
        ..default_config()
    };
    let actor: PcActor = PcActor::new(hot, &mut make_rng()).unwrap();
    let mut logits = vec![0.0; 9];
    logits[0] = 10.0;
    let valid: Vec<usize> = (0..9).collect();
    let mut sampler = StdRng::seed_from_u64(123);
    let seen: std::collections::HashSet<usize> = (0..100)
        .map(|_| actor.select_action(&logits, &valid, SelectionMode::Training, &mut sampler))
        .collect();
    assert!(seen.len() > 1, "High temperature should explore more");
}
1077
#[test]
#[should_panic]
fn test_select_action_empty_valid_panics() {
    // With no legal action there is nothing to choose from.
    let mut rng = make_rng();
    let actor: PcActor = PcActor::new(default_config(), &mut rng).unwrap();
    let logits = vec![0.1; 9];
    let no_valid_actions: &[usize] = &[];
    let _ = actor.select_action(&logits, no_valid_actions, SelectionMode::Training, &mut rng);
}
1086
1087    // ── Weight Update Tests ──────────────────────────────────────────
1088
#[test]
fn test_update_weights_changes_first_layer() {
    let mut actor: PcActor = PcActor::new(default_config(), &mut make_rng()).unwrap();
    let board = vec![1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5];
    let inference = actor.infer(&board);
    let first_before = actor.layers[0].weights.data.clone();
    actor.update_weights(&[0.1; 9], &inference, &board, 1.0);
    assert_ne!(actor.layers[0].weights.data, first_before);
}
1100
#[test]
fn test_update_weights_clips_all_layers() {
    // A huge delta would blow the weights up; every entry must stay within WEIGHT_CLIP.
    let mut actor: PcActor = PcActor::new(default_config(), &mut make_rng()).unwrap();
    let board = vec![1.0; 9];
    let inference = actor.infer(&board);
    actor.update_weights(&[1e6; 9], &inference, &board, 1.0);
    for w in actor
        .layers
        .iter()
        .flat_map(|layer| layer.weights.data.iter().copied())
    {
        assert!(
            w.abs() <= WEIGHT_CLIP + 1e-12,
            "Weight {w} exceeds WEIGHT_CLIP"
        );
    }
}
1118
#[test]
fn test_update_weights_two_hidden_changes_both_layers() {
    let mut actor: PcActor = PcActor::new(two_hidden_config(), &mut make_rng()).unwrap();
    let board = vec![0.5; 9];
    let inference = actor.infer(&board);
    let before: Vec<Vec<f64>> = actor.layers[..2]
        .iter()
        .map(|layer| layer.weights.data.clone())
        .collect();
    actor.update_weights(&[0.1; 9], &inference, &board, 1.0);
    assert_ne!(
        actor.layers[0].weights.data, before[0],
        "Layer 0 should change"
    );
    assert_ne!(
        actor.layers[1].weights.data, before[1],
        "Layer 1 should change"
    );
}
1138
#[test]
#[should_panic(expected = "input size")]
fn test_update_weights_panics_wrong_x_size() {
    // Inference on a valid 9-input board, then a 5-element `x` for the update.
    let mut actor: PcActor = PcActor::new(default_config(), &mut make_rng()).unwrap();
    let inference = actor.infer(&[0.0; 9]);
    let wrong_sized_x = [0.0; 5];
    actor.update_weights(&[0.1; 9], &inference, &wrong_sized_x, 1.0);
}
1149
1150    // ── Zero Hidden Layers Test ─────────────────────────────────
1151
#[test]
fn test_infer_zero_hidden_layers_produces_finite_output() {
    let shallow = PcActorConfig {
        hidden_layers: Vec::new(),
        ..default_config()
    };
    let actor: PcActor = PcActor::new(shallow, &mut make_rng()).unwrap();
    let outcome = actor.infer(&[0.5; 9]);
    // No hidden layers: nothing to concatenate and no hidden states to report.
    assert!(outcome.latent_concat.is_empty());
    assert!(outcome.hidden_states.is_empty());
    assert_eq!(outcome.y_conv.len(), 9);
    assert!(outcome.y_conv.iter().all(|v| v.is_finite()));
}
1166
1167    // ── Config Validation Tests ─────────────────────────────────
1168
#[test]
fn test_new_zero_input_size_returns_error() {
    let invalid = PcActorConfig {
        input_size: 0,
        ..default_config()
    };
    let built: Result<PcActor, _> = PcActor::new(invalid, &mut make_rng());
    // Must fail specifically with a configuration-validation error.
    assert!(matches!(
        built,
        Err(crate::error::PcError::ConfigValidation(_))
    ));
}
1181
#[test]
fn test_new_zero_output_size_returns_error() {
    let invalid = PcActorConfig {
        output_size: 0,
        ..default_config()
    };
    let built: Result<PcActor, _> = PcActor::new(invalid, &mut make_rng());
    assert!(matches!(built, Err(_)));
}
1192
#[test]
fn test_new_zero_temperature_returns_error() {
    let invalid = PcActorConfig {
        temperature: 0.0,
        ..default_config()
    };
    let built: Result<PcActor, _> = PcActor::new(invalid, &mut make_rng());
    assert!(matches!(built, Err(_)));
}
1203
#[test]
fn test_new_negative_temperature_returns_error() {
    let invalid = PcActorConfig {
        temperature: -1.0,
        ..default_config()
    };
    let built: Result<PcActor, _> = PcActor::new(invalid, &mut make_rng());
    assert!(matches!(built, Err(_)));
}
1214
1215    // ── Residual / ReZero Config Tests ────────────────────────
1216
#[test]
fn test_default_config_residual_false() {
    let PcActorConfig { residual, .. } = default_config();
    assert!(!residual);
}
1222
#[test]
fn test_default_config_rezero_init() {
    let PcActorConfig { rezero_init, .. } = default_config();
    assert!((rezero_init - 0.001).abs() < 1e-12);
}
1228
#[test]
fn test_new_negative_rezero_init_returns_error() {
    let invalid = PcActorConfig {
        residual: true,
        rezero_init: -0.1,
        ..default_config()
    };
    let built: Result<PcActor, _> = PcActor::new(invalid, &mut make_rng());
    assert!(matches!(built, Err(_)));
}
1240
#[test]
fn test_residual_mixed_sizes_accepted() {
    // Residual mode tolerates hidden layers of different widths (27 -> 18).
    let tanh = |size| LayerDef {
        size,
        activation: Activation::Tanh,
    };
    let config = PcActorConfig {
        residual: true,
        hidden_layers: vec![tanh(27), tanh(18)],
        ..default_config()
    };
    let built: Result<PcActor, _> = PcActor::new(config, &mut make_rng());
    assert!(matches!(built, Ok(_)));
}
1261
#[test]
fn test_residual_mixed_sizes_all_skip() {
    // [27, 27, 18]: every layer i >= 1 gets a skip connection, an identity for
    // 27 -> 27 and a projection for 27 -> 18, so there are two ReZero scalars.
    let tanh = |size| LayerDef {
        size,
        activation: Activation::Tanh,
    };
    let config = PcActorConfig {
        residual: true,
        hidden_layers: vec![tanh(27), tanh(27), tanh(18)],
        ..default_config()
    };
    let actor: PcActor = PcActor::new(config, &mut make_rng()).unwrap();
    assert_eq!(actor.rezero_alpha.len(), 2);
}
1288
#[test]
fn test_residual_heterogeneous_has_projection() {
    // [27, 18]: widths differ, so the skip needs an 18x27 (out x in) projection.
    let tanh = |size| LayerDef {
        size,
        activation: Activation::Tanh,
    };
    let config = PcActorConfig {
        residual: true,
        hidden_layers: vec![tanh(27), tanh(18)],
        ..default_config()
    };
    let actor: PcActor = PcActor::new(config, &mut make_rng()).unwrap();
    assert_eq!(actor.rezero_alpha.len(), 1);
    assert_eq!(actor.skip_projections.len(), 1);
    let projection = actor.skip_projections[0].as_ref();
    assert!(projection.is_some());
    let projection = projection.unwrap();
    assert_eq!(projection.rows, 18); // output dim
    assert_eq!(projection.cols, 27); // input dim
}
1315
#[test]
fn test_residual_homogeneous_no_projection() {
    // [27, 27]: equal widths, so the single skip slot holds no projection matrix.
    let actor: PcActor = PcActor::new(residual_two_hidden_config(), &mut make_rng()).unwrap();
    assert!(matches!(actor.skip_projections.as_slice(), [None]));
}
1324
#[test]
fn test_residual_mixed_sizes_infer_finite() {
    let tanh = |size| LayerDef {
        size,
        activation: Activation::Tanh,
    };
    let config = PcActorConfig {
        residual: true,
        hidden_layers: vec![tanh(27), tanh(27), tanh(18)],
        ..default_config()
    };
    let actor: PcActor = PcActor::new(config, &mut make_rng()).unwrap();
    let outcome = actor.infer(&[0.5; 9]);
    assert!(outcome.y_conv.iter().all(|v| v.is_finite()));
    assert_eq!(outcome.hidden_states.len(), 3);
    assert_eq!(outcome.latent_concat.len(), 27 + 27 + 18);
}
1354
#[test]
fn test_residual_same_size_hidden_layers_accepted() {
    let layer = LayerDef {
        size: 27,
        activation: Activation::Tanh,
    };
    let config = PcActorConfig {
        residual: true,
        hidden_layers: vec![layer; 2],
        ..default_config()
    };
    let built: Result<PcActor, _> = PcActor::new(config, &mut make_rng());
    assert!(matches!(built, Ok(_)));
}
1375
1376    fn residual_two_hidden_config() -> PcActorConfig {
1377        PcActorConfig {
1378            residual: true,
1379            hidden_layers: vec![
1380                LayerDef {
1381                    size: 27,
1382                    activation: Activation::Tanh,
1383                },
1384                LayerDef {
1385                    size: 27,
1386                    activation: Activation::Tanh,
1387                },
1388            ],
1389            ..default_config()
1390        }
1391    }
1392
#[test]
fn test_non_residual_actor_empty_rezero_alpha() {
    // Without residual connections there are no ReZero scalars at all.
    let actor: PcActor = PcActor::new(default_config(), &mut make_rng()).unwrap();
    assert_eq!(actor.rezero_alpha.len(), 0);
}
1399
#[test]
fn test_residual_two_hidden_one_rezero_alpha() {
    // Two hidden layers give exactly one skip (into layer 1), so one scalar.
    let actor: PcActor = PcActor::new(residual_two_hidden_config(), &mut make_rng()).unwrap();
    let alpha_count = actor.rezero_alpha.len();
    assert_eq!(alpha_count, 1);
}
1406
#[test]
fn test_residual_three_hidden_two_rezero_alpha() {
    let layer = LayerDef {
        size: 27,
        activation: Activation::Tanh,
    };
    let config = PcActorConfig {
        residual: true,
        hidden_layers: vec![layer; 3],
        ..default_config()
    };
    let actor: PcActor = PcActor::new(config, &mut make_rng()).unwrap();
    // Layers 1 and 2 each get a skip connection.
    assert_eq!(actor.rezero_alpha.len(), 2);
}
1431
#[test]
fn test_rezero_alpha_initialized_to_rezero_init() {
    let mut config = residual_two_hidden_config();
    config.rezero_init = 0.005;
    let actor: PcActor = PcActor::new(config, &mut make_rng()).unwrap();
    let initial_alpha = actor.rezero_alpha[0];
    assert!((initial_alpha - 0.005).abs() < 1e-12);
}
1442
#[test]
fn test_residual_single_hidden_zero_rezero_alpha() {
    // A single hidden layer has no predecessor to skip from.
    let mut config = default_config();
    config.residual = true;
    let actor: PcActor = PcActor::new(config, &mut make_rng()).unwrap();
    assert_eq!(actor.rezero_alpha.len(), 0);
}
1453
#[test]
fn test_residual_single_hidden_accepted() {
    let mut config = default_config();
    config.residual = true;
    let built: Result<PcActor, _> = PcActor::new(config, &mut make_rng());
    assert!(matches!(built, Ok(_)));
}
1464
1465    // ── Local Learning (PC-based weight updates) Tests ──────────
1466
1467    // ── Residual Inference Tests ──────────────────────────────
1468
#[test]
fn test_residual_false_identical_to_non_residual() {
    // Explicitly setting `residual: false` must reproduce the plain two-hidden-layer
    // actor exactly (same seed, hence same weights and same outputs).
    // `two_hidden_config` already has `residual: false`, so this pins determinism of
    // the non-residual path under an explicit flag.
    let input = vec![1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5];
    let mut rng1 = make_rng();
    let actor1: PcActor = PcActor::new(two_hidden_config(), &mut rng1).unwrap();
    let result1 = actor1.infer(&input);

    let mut rng2 = make_rng();
    let config2 = PcActorConfig {
        residual: false,
        ..two_hidden_config()
    };
    let actor2: PcActor = PcActor::new(config2, &mut rng2).unwrap();
    let result2 = actor2.infer(&input);

    // `zip` stops at the shorter vector, so check lengths first; otherwise a shape
    // mismatch would pass silently.
    assert_eq!(result1.y_conv.len(), result2.y_conv.len());
    for (a, b) in result1.y_conv.iter().zip(result2.y_conv.iter()) {
        assert!((a - b).abs() < 1e-12);
    }
    assert_eq!(result1.latent_concat.len(), result2.latent_concat.len());
    for (a, b) in result1
        .latent_concat
        .iter()
        .zip(result2.latent_concat.iter())
    {
        assert!((a - b).abs() < 1e-12);
    }
}
1488
#[test]
fn test_residual_rezero_zero_second_hidden_near_identity() {
    // With rezero_init = 0 the residual branch contributes nothing, and with
    // alpha = 0 the PC loop does not move the states. The second hidden state must
    // therefore equal the first (identity skip, 27 -> 27).
    let mut rng = make_rng();
    let config = PcActorConfig {
        rezero_init: 0.0,
        alpha: 0.0,
        ..residual_two_hidden_config()
    };
    let actor: PcActor = PcActor::new(config, &mut rng).unwrap();
    let result = actor.infer(&[0.5; 9]);
    let h0 = &result.hidden_states[0];
    let h1 = &result.hidden_states[1];
    // Guard against `zip` silently truncating on a shape mismatch.
    assert_eq!(h0.len(), h1.len());
    for (a, b) in h0.iter().zip(h1.iter()) {
        assert!(
            (a - b).abs() < 1e-12,
            "With rezero_init=0, h[1] should equal h[0]"
        );
    }
}
1508
#[test]
fn test_residual_infer_all_outputs_finite() {
    let actor: PcActor = PcActor::new(residual_two_hidden_config(), &mut make_rng()).unwrap();
    let outcome = actor.infer(&[0.5; 9]);
    assert!(outcome.y_conv.iter().all(|v| v.is_finite()));
    assert!(outcome.latent_concat.iter().all(|v| v.is_finite()));
    assert!(outcome.surprise_score.is_finite());
}
1522
#[test]
fn test_residual_latent_concat_size() {
    let actor: PcActor = PcActor::new(residual_two_hidden_config(), &mut make_rng()).unwrap();
    let latent = actor.infer(&[0.5; 9]).latent_concat;
    assert_eq!(latent.len(), 27 + 27);
}
1530
#[test]
fn test_residual_pc_loop_completes() {
    let mut config = residual_two_hidden_config();
    config.alpha = 0.03;
    config.max_steps = 5;
    let actor: PcActor = PcActor::new(config, &mut make_rng()).unwrap();
    let steps = actor.infer(&[0.5; 9]).steps_used;
    // At least one step, never more than the budget of 5.
    assert!((1..=5).contains(&steps));
}
1544
#[test]
fn test_residual_hidden_states_count() {
    let actor: PcActor = PcActor::new(residual_two_hidden_config(), &mut make_rng()).unwrap();
    let states = actor.infer(&[0.5; 9]).hidden_states;
    assert_eq!(states.len(), 2);
}
1552
#[test]
fn test_residual_infer_does_not_modify_weights() {
    // Neither the layer weights nor the ReZero scalars may change during inference.
    let actor: PcActor = PcActor::new(residual_two_hidden_config(), &mut make_rng()).unwrap();
    let weights_of = |a: &PcActor| -> Vec<Vec<f64>> {
        a.layers.iter().map(|layer| layer.weights.data.clone()).collect()
    };
    let weights_before = weights_of(&actor);
    let alpha_before = actor.rezero_alpha.clone();
    let _ = actor.infer(&[0.5; 9]);
    assert_eq!(weights_of(&actor), weights_before);
    assert_eq!(actor.rezero_alpha, alpha_before);
}
1569
#[test]
fn test_residual_three_hidden_infer_finite() {
    let layer = LayerDef {
        size: 27,
        activation: Activation::Tanh,
    };
    let config = PcActorConfig {
        residual: true,
        hidden_layers: vec![layer; 3],
        ..default_config()
    };
    let actor: PcActor = PcActor::new(config, &mut make_rng()).unwrap();
    let outputs = actor.infer(&[0.5; 9]).y_conv;
    assert!(outputs.iter().all(|v| v.is_finite()));
}
1597
#[test]
fn test_residual_tanh_components_populated() {
    // Layer 0 has no skip, so no separate tanh component. Layer 1 has a skip and
    // records its 27-wide tanh component.
    let actor: PcActor = PcActor::new(residual_two_hidden_config(), &mut make_rng()).unwrap();
    let outcome = actor.infer(&[0.5; 9]);
    assert!(matches!(
        outcome.tanh_components.as_slice(),
        [None, Some(component)] if component.len() == 27
    ));
}
1608
#[test]
fn test_residual_pc_prediction_uses_tanh_component_not_full_state() {
    // With rezero_init = 1.0 the second hidden state is h[1] = tanh_out + h[0], which
    // is far from tanh_out alone. A PC prediction built from the full residual state
    // instead of the tanh component would be badly mis-scaled.
    //
    // This test runs ONE inference and checks that the PC loop stays numerically
    // well behaved: finite, non-negative surprise, finite per-layer prediction
    // errors, and termination within the step budget.
    // NOTE(review): it does not compare against an alpha = 0 baseline or assert that
    // surprise decreases over the loop. That would need per-step surprise values.
    let mut rng = make_rng();
    let config = PcActorConfig {
        rezero_init: 1.0,
        alpha: 0.1,
        max_steps: 20,
        tol: 0.001,
        min_steps: 1,
        ..residual_two_hidden_config()
    };
    let actor: PcActor = PcActor::new(config, &mut rng).unwrap();
    let result = actor.infer(&[1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5]);
    assert!(result.surprise_score.is_finite());
    assert!(result.surprise_score >= 0.0);
    assert!(result.steps_used <= 20);
    for errors in &result.prediction_errors {
        for &e in errors {
            assert!(e.is_finite(), "PC prediction error not finite: {e}");
        }
    }
}
1637
1638    // ── Residual Backward Tests ────────────────────────────────
1639
1640    #[test]
1641    fn test_residual_false_update_identical_to_non_residual() {
1642        let input = vec![1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5];
1643        let delta = vec![0.1; 9];
1644
1645        let mut rng1 = make_rng();
1646        let mut actor1: PcActor = PcActor::new(two_hidden_config(), &mut rng1).unwrap();
1647        let infer1 = actor1.infer(&input);
1648        actor1.update_weights(&delta, &infer1, &input, 1.0);
1649
1650        let mut rng2 = make_rng();
1651        let config2 = PcActorConfig {
1652            residual: false,
1653            ..two_hidden_config()
1654        };
1655        let mut actor2: PcActor = PcActor::new(config2, &mut rng2).unwrap();
1656        let infer2 = actor2.infer(&input);
1657        actor2.update_weights(&delta, &infer2, &input, 1.0);
1658
1659        for i in 0..actor1.layers.len() {
1660            assert_eq!(actor1.layers[i].weights.data, actor2.layers[i].weights.data);
1661        }
1662    }
1663
1664    #[test]
1665    fn test_residual_update_changes_all_layer_weights() {
1666        let mut rng = make_rng();
1667        let mut actor: PcActor = PcActor::new(residual_two_hidden_config(), &mut rng).unwrap();
1668        let input = vec![0.5; 9];
1669        let infer_result = actor.infer(&input);
1670        let w0 = actor.layers[0].weights.data.clone();
1671        let w1 = actor.layers[1].weights.data.clone();
1672        let w2 = actor.layers[2].weights.data.clone();
1673        actor.update_weights(&[0.1; 9], &infer_result, &input, 1.0);
1674        assert_ne!(actor.layers[0].weights.data, w0, "Layer 0 should change");
1675        assert_ne!(actor.layers[1].weights.data, w1, "Layer 1 should change");
1676        assert_ne!(
1677            actor.layers[2].weights.data, w2,
1678            "Output layer should change"
1679        );
1680    }
1681
1682    #[test]
1683    fn test_residual_update_changes_rezero_alpha() {
1684        let mut rng = make_rng();
1685        let mut actor: PcActor = PcActor::new(residual_two_hidden_config(), &mut rng).unwrap();
1686        let input = vec![0.5; 9];
1687        let infer_result = actor.infer(&input);
1688        let alpha_before = actor.rezero_alpha.clone();
1689        actor.update_weights(&[0.1; 9], &infer_result, &input, 1.0);
1690        assert_ne!(
1691            actor.rezero_alpha, alpha_before,
1692            "rezero_alpha should be updated by backprop"
1693        );
1694    }
1695
1696    #[test]
1697    fn test_residual_update_clips_weights() {
1698        let mut rng = make_rng();
1699        let mut actor: PcActor = PcActor::new(residual_two_hidden_config(), &mut rng).unwrap();
1700        let input = vec![1.0; 9];
1701        let infer_result = actor.infer(&input);
1702        actor.update_weights(&[1e6; 9], &infer_result, &input, 1.0);
1703        for layer in &actor.layers {
1704            for &w in &layer.weights.data {
1705                assert!(
1706                    w.abs() <= WEIGHT_CLIP + 1e-12,
1707                    "Weight {w} exceeds WEIGHT_CLIP"
1708                );
1709            }
1710        }
1711    }
1712
1713    #[test]
1714    fn test_residual_gradient_stronger_than_non_residual() {
1715        let input = vec![1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5];
1716        let delta = vec![0.1; 9];
1717
1718        // Non-residual 2 hidden layers (27, 27)
1719        let mut rng1 = make_rng();
1720        let config1 = PcActorConfig {
1721            hidden_layers: vec![
1722                LayerDef {
1723                    size: 27,
1724                    activation: Activation::Tanh,
1725                },
1726                LayerDef {
1727                    size: 27,
1728                    activation: Activation::Tanh,
1729                },
1730            ],
1731            ..default_config()
1732        };
1733        let mut actor1: PcActor = PcActor::new(config1, &mut rng1).unwrap();
1734        let w0_before1 = actor1.layers[0].weights.data.clone();
1735        let infer1 = actor1.infer(&input);
1736        actor1.update_weights(&delta, &infer1, &input, 1.0);
1737        let change1: f64 = actor1.layers[0]
1738            .weights
1739            .data
1740            .iter()
1741            .zip(w0_before1.iter())
1742            .map(|(a, b)| (a - b).abs())
1743            .sum();
1744
1745        // Residual 2 hidden layers (27, 27) with rezero_init=1.0
1746        let mut rng2 = make_rng();
1747        let config2 = PcActorConfig {
1748            rezero_init: 1.0,
1749            ..residual_two_hidden_config()
1750        };
1751        let mut actor2: PcActor = PcActor::new(config2, &mut rng2).unwrap();
1752        let w0_before2 = actor2.layers[0].weights.data.clone();
1753        let infer2 = actor2.infer(&input);
1754        actor2.update_weights(&delta, &infer2, &input, 1.0);
1755        let change2: f64 = actor2.layers[0]
1756            .weights
1757            .data
1758            .iter()
1759            .zip(w0_before2.iter())
1760            .map(|(a, b)| (a - b).abs())
1761            .sum();
1762
1763        assert!(
1764            change2 > change1,
1765            "Residual should propagate stronger gradient to layer 0: residual={change2:.6}, non-residual={change1:.6}"
1766        );
1767    }
1768
1769    #[test]
1770    fn test_residual_hybrid_lambda_works() {
1771        let mut rng = make_rng();
1772        let config = PcActorConfig {
1773            local_lambda: 0.99,
1774            ..residual_two_hidden_config()
1775        };
1776        let mut actor: PcActor = PcActor::new(config, &mut rng).unwrap();
1777        let input = vec![0.5; 9];
1778        let infer_result = actor.infer(&input);
1779        let w0_before = actor.layers[0].weights.data.clone();
1780        actor.update_weights(&[0.1; 9], &infer_result, &input, 1.0);
1781        assert_ne!(actor.layers[0].weights.data, w0_before);
1782    }
1783
1784    fn local_learning_config() -> PcActorConfig {
1785        PcActorConfig {
1786            local_lambda: 0.0,
1787            ..default_config()
1788        }
1789    }
1790
1791    #[test]
1792    fn test_infer_prediction_errors_count_matches_hidden_layers() {
1793        let mut rng = make_rng();
1794        let actor: PcActor = PcActor::new(default_config(), &mut rng).unwrap();
1795        let result = actor.infer(&[0.0; 9]);
1796        assert_eq!(result.prediction_errors.len(), 1);
1797    }
1798
1799    #[test]
1800    fn test_infer_prediction_errors_two_hidden() {
1801        let mut rng = make_rng();
1802        let actor: PcActor = PcActor::new(two_hidden_config(), &mut rng).unwrap();
1803        let result = actor.infer(&[0.0; 9]);
1804        assert_eq!(result.prediction_errors.len(), 2);
1805    }
1806
1807    #[test]
1808    fn test_infer_prediction_errors_zero_hidden_is_empty() {
1809        let mut rng = make_rng();
1810        let config = PcActorConfig {
1811            hidden_layers: vec![],
1812            ..default_config()
1813        };
1814        let actor: PcActor = PcActor::new(config, &mut rng).unwrap();
1815        let result = actor.infer(&[0.5; 9]);
1816        assert!(result.prediction_errors.is_empty());
1817    }
1818
1819    #[test]
1820    fn test_infer_prediction_errors_all_finite() {
1821        let mut rng = make_rng();
1822        let actor: PcActor = PcActor::new(default_config(), &mut rng).unwrap();
1823        let result = actor.infer(&[1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5]);
1824        for errors in &result.prediction_errors {
1825            for &e in errors {
1826                assert!(e.is_finite(), "prediction error not finite: {e}");
1827            }
1828        }
1829    }
1830
1831    #[test]
1832    fn test_infer_prediction_errors_size_matches_hidden_layer_size() {
1833        let mut rng = make_rng();
1834        let actor: PcActor = PcActor::new(default_config(), &mut rng).unwrap();
1835        let result = actor.infer(&[0.0; 9]);
1836        // default_config has one hidden layer of size 18
1837        assert_eq!(result.prediction_errors[0].len(), 18);
1838    }
1839
1840    #[test]
1841    fn test_local_learning_config_accepted() {
1842        let mut rng = make_rng();
1843        let config = local_learning_config();
1844        assert!((config.local_lambda).abs() < f64::EPSILON);
1845        let actor: Result<PcActor, _> = PcActor::new(config, &mut rng);
1846        assert!(actor.is_ok());
1847    }
1848
1849    #[test]
1850    fn test_local_learning_update_changes_weights() {
1851        let mut rng = make_rng();
1852        let mut actor: PcActor = PcActor::new(local_learning_config(), &mut rng).unwrap();
1853        let input = vec![1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5];
1854        let infer_result = actor.infer(&input);
1855        let weights_before = actor.layers[0].weights.data.clone();
1856        let delta = vec![0.1; 9];
1857        actor.update_weights(&delta, &infer_result, &input, 1.0);
1858        assert_ne!(actor.layers[0].weights.data, weights_before);
1859    }
1860
1861    #[test]
1862    fn test_local_learning_clips_weights() {
1863        let mut rng = make_rng();
1864        let mut actor: PcActor = PcActor::new(local_learning_config(), &mut rng).unwrap();
1865        let input = vec![1.0; 9];
1866        let infer_result = actor.infer(&input);
1867        let delta = vec![1e6; 9];
1868        actor.update_weights(&delta, &infer_result, &input, 1.0);
1869        for layer in &actor.layers {
1870            for &w in &layer.weights.data {
1871                assert!(
1872                    w.abs() <= WEIGHT_CLIP + 1e-12,
1873                    "Weight {w} exceeds WEIGHT_CLIP"
1874                );
1875            }
1876        }
1877    }
1878
1879    #[test]
1880    fn test_local_learning_two_hidden_changes_both() {
1881        let mut rng = make_rng();
1882        let config = PcActorConfig {
1883            local_lambda: 0.0,
1884            ..two_hidden_config()
1885        };
1886        let mut actor: PcActor = PcActor::new(config, &mut rng).unwrap();
1887        let input = vec![0.5; 9];
1888        let infer_result = actor.infer(&input);
1889        let w0_before = actor.layers[0].weights.data.clone();
1890        let w1_before = actor.layers[1].weights.data.clone();
1891        let delta = vec![0.1; 9];
1892        actor.update_weights(&delta, &infer_result, &input, 1.0);
1893        assert_ne!(
1894            actor.layers[0].weights.data, w0_before,
1895            "Layer 0 should change"
1896        );
1897        assert_ne!(
1898            actor.layers[1].weights.data, w1_before,
1899            "Layer 1 should change"
1900        );
1901    }
1902
1903    #[test]
1904    fn test_local_learning_differs_from_backprop() {
1905        let input = vec![1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5];
1906        let delta = vec![0.1; 9];
1907
1908        // Backprop actor
1909        let mut rng1 = make_rng();
1910        let mut bp_actor: PcActor = PcActor::new(default_config(), &mut rng1).unwrap();
1911        let bp_infer = bp_actor.infer(&input);
1912        bp_actor.update_weights(&delta, &bp_infer, &input, 1.0);
1913
1914        // Local learning actor (same initial weights)
1915        let mut rng2 = make_rng();
1916        let mut ll_actor: PcActor = PcActor::new(local_learning_config(), &mut rng2).unwrap();
1917        let ll_infer = ll_actor.infer(&input);
1918        ll_actor.update_weights(&delta, &ll_infer, &input, 1.0);
1919
1920        // Hidden layer weights should differ between the two approaches
1921        assert_ne!(
1922            bp_actor.layers[0].weights.data, ll_actor.layers[0].weights.data,
1923            "Local learning should produce different weight updates than backprop"
1924        );
1925    }
1926
1927    // ── Hybrid Learning (local_lambda) Tests ────────────────────
1928
1929    fn hybrid_config(lambda: f64) -> PcActorConfig {
1930        PcActorConfig {
1931            local_lambda: lambda,
1932            ..default_config()
1933        }
1934    }
1935
1936    #[test]
1937    fn test_local_lambda_one_equals_backprop() {
1938        let input = vec![1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5];
1939        let delta = vec![0.1; 9];
1940
1941        // Pure backprop (local_learning=false, default)
1942        let mut rng1 = make_rng();
1943        let mut bp_actor: PcActor = PcActor::new(default_config(), &mut rng1).unwrap();
1944        let bp_infer = bp_actor.infer(&input);
1945        bp_actor.update_weights(&delta, &bp_infer, &input, 1.0);
1946
1947        // lambda=1.0 should be identical to backprop
1948        let mut rng2 = make_rng();
1949        let mut lam_actor: PcActor = PcActor::new(hybrid_config(1.0), &mut rng2).unwrap();
1950        let lam_infer = lam_actor.infer(&input);
1951        lam_actor.update_weights(&delta, &lam_infer, &input, 1.0);
1952
1953        assert_eq!(
1954            bp_actor.layers[0].weights.data, lam_actor.layers[0].weights.data,
1955            "lambda=1.0 should produce identical weights to pure backprop"
1956        );
1957    }
1958
1959    #[test]
1960    fn test_local_lambda_zero_equals_local_learning() {
1961        let input = vec![1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5];
1962        let delta = vec![0.1; 9];
1963
1964        // Pure local (local_learning=true)
1965        let mut rng1 = make_rng();
1966        let mut ll_actor: PcActor = PcActor::new(local_learning_config(), &mut rng1).unwrap();
1967        let ll_infer = ll_actor.infer(&input);
1968        ll_actor.update_weights(&delta, &ll_infer, &input, 1.0);
1969
1970        // lambda=0.0 should be identical to pure local
1971        let mut rng2 = make_rng();
1972        let mut lam_actor: PcActor = PcActor::new(hybrid_config(0.0), &mut rng2).unwrap();
1973        let lam_infer = lam_actor.infer(&input);
1974        lam_actor.update_weights(&delta, &lam_infer, &input, 1.0);
1975
1976        assert_eq!(
1977            ll_actor.layers[0].weights.data, lam_actor.layers[0].weights.data,
1978            "lambda=0.0 should produce identical weights to pure local learning"
1979        );
1980    }
1981
1982    #[test]
1983    fn test_local_lambda_half_differs_from_both_pure_modes() {
1984        let input = vec![1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5];
1985        let delta = vec![0.1; 9];
1986
1987        // Pure backprop
1988        let mut rng1 = make_rng();
1989        let mut bp_actor: PcActor = PcActor::new(default_config(), &mut rng1).unwrap();
1990        let bp_infer = bp_actor.infer(&input);
1991        bp_actor.update_weights(&delta, &bp_infer, &input, 1.0);
1992
1993        // Pure local
1994        let mut rng2 = make_rng();
1995        let mut ll_actor: PcActor = PcActor::new(local_learning_config(), &mut rng2).unwrap();
1996        let ll_infer = ll_actor.infer(&input);
1997        ll_actor.update_weights(&delta, &ll_infer, &input, 1.0);
1998
1999        // Hybrid lambda=0.5
2000        let mut rng3 = make_rng();
2001        let mut hy_actor: PcActor = PcActor::new(hybrid_config(0.5), &mut rng3).unwrap();
2002        let hy_infer = hy_actor.infer(&input);
2003        hy_actor.update_weights(&delta, &hy_infer, &input, 1.0);
2004
2005        assert_ne!(
2006            hy_actor.layers[0].weights.data, bp_actor.layers[0].weights.data,
2007            "lambda=0.5 should differ from pure backprop"
2008        );
2009        assert_ne!(
2010            hy_actor.layers[0].weights.data, ll_actor.layers[0].weights.data,
2011            "lambda=0.5 should differ from pure local"
2012        );
2013    }
2014
2015    #[test]
2016    fn test_local_lambda_changes_weights() {
2017        let mut rng = make_rng();
2018        let mut actor: PcActor = PcActor::new(hybrid_config(0.5), &mut rng).unwrap();
2019        let input = vec![1.0, -1.0, 0.5, -0.5, 0.0, 1.0, -1.0, 0.5, -0.5];
2020        let infer_result = actor.infer(&input);
2021        let weights_before = actor.layers[0].weights.data.clone();
2022        let delta = vec![0.1; 9];
2023        actor.update_weights(&delta, &infer_result, &input, 1.0);
2024        assert_ne!(actor.layers[0].weights.data, weights_before);
2025    }
2026
2027    #[test]
2028    fn test_local_lambda_clips_weights() {
2029        let mut rng = make_rng();
2030        let mut actor: PcActor = PcActor::new(hybrid_config(0.5), &mut rng).unwrap();
2031        let input = vec![1.0; 9];
2032        let infer_result = actor.infer(&input);
2033        let delta = vec![1e6; 9];
2034        actor.update_weights(&delta, &infer_result, &input, 1.0);
2035        for layer in &actor.layers {
2036            for &w in &layer.weights.data {
2037                assert!(
2038                    w.abs() <= WEIGHT_CLIP + 1e-12,
2039                    "Weight {w} exceeds WEIGHT_CLIP"
2040                );
2041            }
2042        }
2043    }
2044
2045    #[test]
2046    fn test_local_lambda_negative_returns_error() {
2047        let mut rng = make_rng();
2048        let config = hybrid_config(-0.1);
2049        let result: Result<PcActor, _> = PcActor::new(config, &mut rng);
2050        assert!(result.is_err());
2051    }
2052
2053    #[test]
2054    fn test_local_lambda_above_one_returns_error() {
2055        let mut rng = make_rng();
2056        let config = hybrid_config(1.1);
2057        let result: Result<PcActor, _> = PcActor::new(config, &mut rng);
2058        assert!(result.is_err());
2059    }
2060}