// pc_rl_core/serializer.rs
// Author: Julian Bolivar
// Version: 1.0.0
// Date: 2026-03-25

//! JSON-based weight persistence for the PC-Actor-Critic agent.
//!
//! Provides save/load for complete agent state (weights, config, metadata)
//! and checkpoint support with auto-named files.
//!
//! Serialization always goes through CPU types (`CpuLinAlg`). Generic agents
//! convert to/from CPU weights via `to_weights()` / `from_weights()`.
13use std::path::{Path, PathBuf};
14
15use chrono::Utc;
16use serde::{Deserialize, Serialize};
17
18use crate::error::PcError;
19use crate::layer::Layer;
20use crate::linalg::LinAlg;
21use crate::mlp_critic::MlpCritic;
22use crate::pc_actor::PcActor;
23use crate::pc_actor_critic::{PcActorCritic, PcActorCriticConfig};
24
25/// Metadata embedded in every save file.
26///
27/// Tracks version, creation timestamp, episode count, and optional
28/// training metrics for provenance.
/// Metadata embedded in every save file.
///
/// Tracks version, creation timestamp, episode count, and optional
/// training metrics for provenance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMetadata {
    /// Crate version string (populated from `CARGO_PKG_VERSION` at save time).
    pub version: String,
    /// UTC timestamp of when the file was created (RFC 3339 format).
    pub created: String,
    /// Episode number at time of save.
    pub episode: usize,
    /// Optional training statistics snapshot; `None` when the caller
    /// saves without metrics.
    pub metrics: Option<TrainingMetrics>,
}
40
41/// Training statistics snapshot for inclusion in save files.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct TrainingMetrics {
44    /// Fraction of games won.
45    pub win_rate: f64,
46    /// Fraction of games lost.
47    pub loss_rate: f64,
48    /// Fraction of games drawn.
49    pub draw_rate: f64,
50    /// Average surprise score over recent episodes.
51    pub avg_surprise: f64,
52    /// Current curriculum depth level.
53    pub curriculum_depth: usize,
54}
55
56/// Serializable weight snapshot for the PC actor.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct PcActorWeights {
59    /// Layer snapshots in order (hidden layers + output layer).
60    pub layers: Vec<Layer>,
61    /// ReZero scaling factors for residual skip connections.
62    #[serde(default)]
63    pub rezero_alpha: Vec<f64>,
64    /// Projection matrices for heterogeneous skip connections.
65    #[serde(default)]
66    pub skip_projections: Vec<Option<crate::matrix::Matrix>>,
67}
68
69/// Complete save file containing agent state and metadata.
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct SaveFile {
72    /// File metadata (version, timestamp, episode).
73    pub metadata: AgentMetadata,
74    /// Agent configuration.
75    pub config: PcActorCriticConfig,
76    /// Actor network weights.
77    pub actor_weights: PcActorWeights,
78    /// Critic network weights.
79    pub critic_weights: crate::mlp_critic::MlpCriticWeights,
80}
81
82/// Saves the agent's full state to a JSON file.
83///
84/// Creates parent directories if they don't exist. Extracts weights
85/// from both actor and critic via `to_weights()`, bundles with config
86/// and metadata, and writes as pretty-printed JSON.
87///
88/// # Arguments
89///
90/// * `agent` - The agent to save (any `LinAlg` backend).
91/// * `path` - File path for the JSON output.
92/// * `episode` - Current episode number.
93/// * `metrics` - Optional training metrics snapshot.
94///
95/// # Errors
96///
97/// Returns `PcError::Io` on file system errors, `PcError::Serialization`
98/// on JSON encoding errors.
99pub fn save_agent<L: LinAlg>(
100    agent: &PcActorCritic<L>,
101    path: &str,
102    episode: usize,
103    metrics: Option<TrainingMetrics>,
104) -> Result<(), PcError> {
105    let save_file = SaveFile {
106        metadata: AgentMetadata {
107            version: env!("CARGO_PKG_VERSION").to_string(),
108            created: Utc::now().to_rfc3339(),
109            episode,
110            metrics,
111        },
112        config: agent.config.clone(),
113        actor_weights: agent.actor.to_weights(),
114        critic_weights: agent.critic.to_weights(),
115    };
116
117    let json = serde_json::to_string_pretty(&save_file)?;
118
119    // Create parent directories if needed
120    let path = Path::new(path);
121    if let Some(parent) = path.parent() {
122        if !parent.as_os_str().is_empty() {
123            std::fs::create_dir_all(parent)?;
124        }
125    }
126
127    std::fs::write(path, json)?;
128    Ok(())
129}
130
131/// Loads an agent from a JSON save file (CPU backend).
132///
133/// Reads the file, deserializes the `SaveFile`, validates that the
134/// topology matches the config, then reconstructs the agent using
135/// `CpuLinAlg` (the default backend).
136///
137/// # Arguments
138///
139/// * `path` - Path to the JSON save file.
140///
141/// # Errors
142///
143/// Returns `PcError::Io` if the file doesn't exist, `PcError::Serialization`
144/// for invalid JSON, or `PcError::DimensionMismatch` if the saved weights
145/// don't match the config topology.
146pub fn load_agent(path: &str) -> Result<(PcActorCritic, AgentMetadata), PcError> {
147    let json = std::fs::read_to_string(path)?;
148    let save_file: SaveFile = serde_json::from_str(&json)?;
149
150    // Validate actor layer count
151    let expected_actor_layers = save_file.config.actor.hidden_layers.len() + 1;
152    if save_file.actor_weights.layers.len() != expected_actor_layers {
153        return Err(PcError::DimensionMismatch {
154            expected: expected_actor_layers,
155            got: save_file.actor_weights.layers.len(),
156            context: "actor layer count",
157        });
158    }
159
160    // Validate critic layer count
161    let expected_critic_layers = save_file.config.critic.hidden_layers.len() + 1;
162    if save_file.critic_weights.layers.len() != expected_critic_layers {
163        return Err(PcError::DimensionMismatch {
164            expected: expected_critic_layers,
165            got: save_file.critic_weights.layers.len(),
166            context: "critic layer count",
167        });
168    }
169
170    let actor = PcActor::from_weights(save_file.config.actor.clone(), save_file.actor_weights);
171    let critic = MlpCritic::from_weights(save_file.config.critic.clone(), save_file.critic_weights);
172
173    use rand::SeedableRng;
174    let rng = rand::rngs::StdRng::from_entropy();
175
176    let agent = PcActorCritic::from_parts(save_file.config, actor, critic, rng);
177
178    Ok((agent, save_file.metadata))
179}
180
181/// Loads an agent from a JSON save file with a specific `LinAlg` backend.
182///
183/// Same as [`load_agent`] but reconstructs the agent using the specified
184/// backend type `L`. Weights are deserialized as CPU types and then
185/// converted via `PcActor::<L>::from_weights()` and
186/// `MlpCritic::<L>::from_weights()`.
187///
188/// # Arguments
189///
190/// * `path` - Path to the JSON save file.
191///
192/// # Errors
193///
194/// Returns `PcError::Io` if the file doesn't exist, `PcError::Serialization`
195/// for invalid JSON, or `PcError::DimensionMismatch` if the saved weights
196/// don't match the config topology.
197pub fn load_agent_generic<L: LinAlg>(
198    path: &str,
199) -> Result<(PcActorCritic<L>, AgentMetadata), PcError> {
200    let json = std::fs::read_to_string(path)?;
201    let save_file: SaveFile = serde_json::from_str(&json)?;
202
203    // Validate actor layer count
204    let expected_actor_layers = save_file.config.actor.hidden_layers.len() + 1;
205    if save_file.actor_weights.layers.len() != expected_actor_layers {
206        return Err(PcError::DimensionMismatch {
207            expected: expected_actor_layers,
208            got: save_file.actor_weights.layers.len(),
209            context: "actor layer count",
210        });
211    }
212
213    // Validate critic layer count
214    let expected_critic_layers = save_file.config.critic.hidden_layers.len() + 1;
215    if save_file.critic_weights.layers.len() != expected_critic_layers {
216        return Err(PcError::DimensionMismatch {
217            expected: expected_critic_layers,
218            got: save_file.critic_weights.layers.len(),
219            context: "critic layer count",
220        });
221    }
222
223    let actor = PcActor::<L>::from_weights(save_file.config.actor.clone(), save_file.actor_weights);
224    let critic =
225        MlpCritic::<L>::from_weights(save_file.config.critic.clone(), save_file.critic_weights);
226
227    use rand::SeedableRng;
228    let rng = rand::rngs::StdRng::from_entropy();
229
230    let agent = PcActorCritic::from_parts(save_file.config, actor, critic, rng);
231
232    Ok((agent, save_file.metadata))
233}
234
235/// Generates a checkpoint filename with no colons (filesystem-safe).
236///
237/// Format: `checkpoint_ep{N}_{YYYYMMDD_HHMMSS}.json`
238///
239/// # Arguments
240///
241/// * `episode` - Episode number to embed in the filename.
242///
243/// # Examples
244///
245/// ```
246/// use pc_rl_core::serializer::checkpoint_filename;
247///
248/// let name = checkpoint_filename(100);
249/// assert!(name.starts_with("checkpoint_ep100_"));
250/// assert!(name.ends_with(".json"));
251/// assert!(!name.contains(':'));
252/// ```
253pub fn checkpoint_filename(episode: usize) -> String {
254    let now = Utc::now().format("%Y%m%d_%H%M%S");
255    format!("checkpoint_ep{episode}_{now}.json")
256}
257
258/// Saves a checkpoint to a directory with an auto-generated filename.
259///
260/// # Arguments
261///
262/// * `agent` - The agent to checkpoint (any `LinAlg` backend).
263/// * `dir` - Directory where the checkpoint file will be created.
264/// * `episode` - Current episode number.
265/// * `metrics` - Optional training metrics snapshot.
266///
267/// # Returns
268///
269/// The full path to the created checkpoint file.
270///
271/// # Errors
272///
273/// Returns `PcError` on I/O or serialization failures.
274pub fn save_checkpoint<L: LinAlg>(
275    agent: &PcActorCritic<L>,
276    dir: &str,
277    episode: usize,
278    metrics: Option<TrainingMetrics>,
279) -> Result<PathBuf, PcError> {
280    let filename = checkpoint_filename(episode);
281    let path = Path::new(dir).join(filename);
282    let path_str = path.to_string_lossy().to_string();
283    save_agent(agent, &path_str, episode, metrics)?;
284    Ok(path)
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::Activation;
    use crate::layer::LayerDef;
    use crate::mlp_critic::MlpCriticConfig;
    use crate::pc_actor::PcActorConfig;
    use std::fs;

    /// Minimal config used by most tests: 9→18→9 tanh actor and a
    /// 27-input critic with one 36-unit hidden layer.
    fn default_config() -> PcActorCriticConfig {
        PcActorCriticConfig {
            actor: PcActorConfig {
                input_size: 9,
                hidden_layers: vec![LayerDef {
                    size: 18,
                    activation: Activation::Tanh,
                }],
                output_size: 9,
                output_activation: Activation::Tanh,
                alpha: 0.1,
                tol: 0.01,
                min_steps: 1,
                max_steps: 20,
                lr_weights: 0.01,
                synchronous: true,
                temperature: 1.0,
                local_lambda: 1.0,
                residual: false,
                rezero_init: 0.001,
            },
            critic: MlpCriticConfig {
                input_size: 27,
                hidden_layers: vec![LayerDef {
                    size: 36,
                    activation: Activation::Tanh,
                }],
                output_activation: Activation::Linear,
                lr: 0.005,
            },
            gamma: 0.95,
            surprise_low: 0.02,
            surprise_high: 0.15,
            adaptive_surprise: false,
            entropy_coeff: 0.01,
        }
    }

    /// Builds a deterministic default-backend agent (fixed seed 42).
    fn make_agent() -> PcActorCritic {
        let agent: PcActorCritic = PcActorCritic::new(default_config(), 42).unwrap();
        agent
    }

    /// Returns a path inside a shared temp dir, creating the dir if needed.
    /// Tests use distinct filenames to avoid clobbering each other.
    fn temp_path(name: &str) -> String {
        let dir = std::env::temp_dir().join("pc_core_tests");
        fs::create_dir_all(&dir).unwrap();
        dir.join(name).to_string_lossy().to_string()
    }

    /// Asserts two f64 slices are approximately equal (within 1e-15).
    fn assert_vecs_approx_eq(a: &[f64], b: &[f64]) {
        assert_eq!(
            a.len(),
            b.len(),
            "Lengths differ: {} vs {}",
            a.len(),
            b.len()
        );
        for (i, (va, vb)) in a.iter().zip(b.iter()).enumerate() {
            assert!((va - vb).abs() < 1e-15, "Element {i} differs: {va} vs {vb}");
        }
    }

    /// Save → load must preserve actor layer weights and biases exactly.
    #[test]
    fn test_roundtrip_preserves_actor_weights() {
        let agent = make_agent();
        let path = temp_path("test_actor_roundtrip.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();
        for (orig, loaded_layer) in agent.actor.layers.iter().zip(loaded.actor.layers.iter()) {
            assert_vecs_approx_eq(&orig.weights.data, &loaded_layer.weights.data);
            assert_vecs_approx_eq(&orig.bias, &loaded_layer.bias);
        }
        let _ = fs::remove_file(&path);
    }

    /// Save → load must preserve critic layer weights and biases exactly.
    #[test]
    fn test_roundtrip_preserves_critic_weights() {
        let agent = make_agent();
        let path = temp_path("test_critic_roundtrip.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();
        for (orig, loaded_layer) in agent.critic.layers.iter().zip(loaded.critic.layers.iter()) {
            assert_vecs_approx_eq(&orig.weights.data, &loaded_layer.weights.data);
            assert_vecs_approx_eq(&orig.bias, &loaded_layer.bias);
        }
        let _ = fs::remove_file(&path);
    }

    /// Spot-checks that key config fields survive the round trip.
    #[test]
    fn test_roundtrip_preserves_config() {
        let agent = make_agent();
        let path = temp_path("test_config_roundtrip.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();
        assert_eq!(loaded.config.gamma, agent.config.gamma);
        assert_eq!(
            loaded.config.actor.input_size,
            agent.config.actor.input_size
        );
        assert_eq!(
            loaded.config.critic.input_size,
            agent.config.critic.input_size
        );
        assert_eq!(loaded.config.entropy_coeff, agent.config.entropy_coeff);
        let _ = fs::remove_file(&path);
    }

    /// Metadata written at save time must come back populated.
    #[test]
    fn test_metadata_includes_version_and_episode() {
        let agent = make_agent();
        let path = temp_path("test_metadata.json");
        save_agent(&agent, &path, 42, None).unwrap();
        let (_, metadata) = load_agent(&path).unwrap();
        assert!(!metadata.version.is_empty());
        assert_eq!(metadata.episode, 42);
        assert!(!metadata.created.is_empty());
        let _ = fs::remove_file(&path);
    }

    /// Colons are invalid in Windows filenames; the generated name must
    /// not contain any.
    #[test]
    fn test_checkpoint_filename_no_colons() {
        let name = checkpoint_filename(100);
        assert!(!name.contains(':'), "Filename contains colons: {name}");
    }

    #[test]
    fn test_checkpoint_filename_contains_episode_number() {
        let name = checkpoint_filename(42);
        assert!(
            name.contains("ep42"),
            "Filename doesn't contain episode number: {name}"
        );
        assert!(name.ends_with(".json"));
    }

    /// A missing file must surface as `PcError::Io`, not a panic.
    #[test]
    fn test_load_nonexistent_returns_error() {
        let result = load_agent("/nonexistent/path/agent.json");
        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(
            matches!(err, PcError::Io(_)),
            "Expected PcError::Io, got: {err}"
        );
    }

    /// Malformed JSON must surface as `PcError::Serialization`.
    #[test]
    fn test_load_invalid_json_returns_error() {
        let path = temp_path("test_invalid.json");
        fs::write(&path, "not valid json {{{").unwrap();
        let result = load_agent(&path);
        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(
            matches!(err, PcError::Serialization(_)),
            "Expected PcError::Serialization, got: {err}"
        );
        let _ = fs::remove_file(&path);
    }

    /// A config/weights topology mismatch must be rejected at load time.
    #[test]
    fn test_load_topology_mismatch_returns_error() {
        let agent = make_agent();
        let path = temp_path("test_mismatch.json");
        save_agent(&agent, &path, 0, None).unwrap();

        // Tamper: read JSON, change actor layer count in config
        let json = fs::read_to_string(&path).unwrap();
        let mut save_file: SaveFile = serde_json::from_str(&json).unwrap();
        // Add an extra hidden layer to config (but not weights)
        save_file.config.actor.hidden_layers.push(LayerDef {
            size: 10,
            activation: Activation::Relu,
        });
        let tampered = serde_json::to_string_pretty(&save_file).unwrap();
        fs::write(&path, tampered).unwrap();

        let result = load_agent(&path);
        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(
            matches!(err, PcError::DimensionMismatch { .. }),
            "Expected PcError::DimensionMismatch, got: {err}"
        );
        let _ = fs::remove_file(&path);
    }

    /// Loading must reseed the RNG from entropy: two agents restored from
    /// the same file should explore differently.
    /// NOTE(review): probabilistic — could in principle flake if both
    /// 20-action sequences coincide, though that is astronomically unlikely.
    #[test]
    fn test_load_agent_uses_entropy_seed_not_fixed() {
        let agent = make_agent();
        let path = temp_path("test_seed_entropy.json");
        save_agent(&agent, &path, 10, None).unwrap();

        let (mut loaded1, _) = load_agent(&path).unwrap();
        let (mut loaded2, _) = load_agent(&path).unwrap();

        // Both agents should produce different action sequences
        // because they use entropy-based RNG seeding
        let input = vec![0.5; 9];
        let valid: Vec<usize> = (0..9).collect();

        let mut actions1 = Vec::new();
        let mut actions2 = Vec::new();
        for _ in 0..20 {
            let (a1, _) = loaded1.act(&input, &valid, crate::pc_actor::SelectionMode::Training);
            let (a2, _) = loaded2.act(&input, &valid, crate::pc_actor::SelectionMode::Training);
            actions1.push(a1);
            actions2.push(a2);
        }

        assert_ne!(
            actions1, actions2,
            "Two loaded agents should have different exploration due to entropy seeding"
        );
        let _ = fs::remove_file(&path);
    }

    /// Deterministic inference must be bit-for-bit reproducible (within
    /// 1e-12) on a loaded agent.
    #[test]
    fn test_loaded_agent_produces_identical_inference() {
        let agent = make_agent();
        let path = temp_path("test_identical_infer.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();

        let input = vec![0.5, -0.5, 1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0];
        let orig_result = agent.infer(&input);
        let loaded_result = loaded.infer(&input);

        // y_conv must be identical
        assert_eq!(orig_result.y_conv.len(), loaded_result.y_conv.len());
        for (a, b) in orig_result.y_conv.iter().zip(loaded_result.y_conv.iter()) {
            assert!((a - b).abs() < 1e-12, "y_conv differs: {a} vs {b}");
        }
        // latent_concat must be identical
        for (a, b) in orig_result
            .latent_concat
            .iter()
            .zip(loaded_result.latent_concat.iter())
        {
            assert!((a - b).abs() < 1e-12, "latent_concat differs: {a} vs {b}");
        }
        let _ = fs::remove_file(&path);
    }

    /// `save_agent` must create missing parent directories itself.
    #[test]
    fn test_save_creates_parent_directories() {
        let dir = std::env::temp_dir()
            .join("pc_core_tests")
            .join("nested")
            .join("deep");
        let path = dir.join("agent.json").to_string_lossy().to_string();

        // Remove if exists from prior run
        let _ = fs::remove_dir_all(&dir);

        let agent = make_agent();
        save_agent(&agent, &path, 0, None).unwrap();
        assert!(Path::new(&path).exists());

        // Cleanup
        let _ = fs::remove_dir_all(std::env::temp_dir().join("pc_core_tests").join("nested"));
    }

    /// Trained (non-initial) rezero_alpha values must survive a round trip
    /// instead of being reset to `rezero_init`.
    #[test]
    fn test_roundtrip_preserves_modified_rezero_alpha() {
        use crate::pc_actor::SelectionMode;
        let config = PcActorCriticConfig {
            actor: PcActorConfig {
                residual: true,
                rezero_init: 0.005,
                hidden_layers: vec![
                    LayerDef {
                        size: 27,
                        activation: Activation::Tanh,
                    },
                    LayerDef {
                        size: 27,
                        activation: Activation::Tanh,
                    },
                ],
                ..default_config().actor
            },
            critic: MlpCriticConfig {
                input_size: 63,
                ..default_config().critic
            },
            ..default_config()
        };
        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
        // Train one step to modify rezero_alpha
        let input = vec![0.5; 9];
        let valid: Vec<usize> = (0..9).collect();
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        let trajectory = vec![crate::pc_actor_critic::TrajectoryStep {
            input: input.clone(),
            latent_concat: infer.latent_concat,
            y_conv: infer.y_conv,
            hidden_states: infer.hidden_states,
            prediction_errors: infer.prediction_errors,
            tanh_components: infer.tanh_components,
            action,
            valid_actions: valid,
            reward: 1.0,
            surprise_score: infer.surprise_score,
            steps_used: infer.steps_used,
        }];
        agent.learn(&trajectory);
        let alpha_after_train = agent.actor.rezero_alpha.clone();
        // Alpha should have changed from init
        // NOTE(review): with two hidden layers `rezero_alpha` presumably has
        // more than one element, so this inequality against a 1-element vec
        // holds even if training left every alpha at 0.005 — consider an
        // element-wise comparison against 0.005 instead. TODO confirm.
        assert_ne!(alpha_after_train, vec![0.005]);

        let path = temp_path("test_rezero_roundtrip.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();
        assert_eq!(
            loaded.actor.rezero_alpha, alpha_after_train,
            "Loaded rezero_alpha should match trained value, not rezero_init"
        );
        let _ = fs::remove_file(&path);
    }

    /// Non-residual agents (empty rezero_alpha) must round-trip unchanged —
    /// backward compatibility with the `#[serde(default)]` fields.
    #[test]
    fn test_roundtrip_non_residual_backward_compat() {
        let agent = make_agent();
        assert!(agent.actor.rezero_alpha.is_empty());

        let path = temp_path("test_nonresidual_compat.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();
        assert!(loaded.actor.rezero_alpha.is_empty());
        let _ = fs::remove_file(&path);
    }

    /// `load_agent` and `load_agent_generic::<CpuLinAlg>` must produce
    /// agents with identical inference behavior.
    #[test]
    fn test_load_agent_generic_matches_load_agent() {
        let agent = make_agent();
        let path = temp_path("test_generic_load.json");
        save_agent(&agent, &path, 10, None).unwrap();

        let (loaded_default, _) = load_agent(&path).unwrap();
        let (loaded_generic, _) =
            load_agent_generic::<crate::linalg::cpu::CpuLinAlg>(&path).unwrap();

        let input = vec![0.5, -0.5, 1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0];
        let r1 = loaded_default.infer(&input);
        let r2 = loaded_generic.infer(&input);

        for (a, b) in r1.y_conv.iter().zip(r2.y_conv.iter()) {
            assert!((a - b).abs() < 1e-15, "y_conv differs: {a} vs {b}");
        }
        let _ = fs::remove_file(&path);
    }

    /// Trained skip-projection matrices must survive a round trip exactly.
    #[test]
    fn test_roundtrip_preserves_skip_projections_directly() {
        use crate::pc_actor::SelectionMode;
        let config = PcActorCriticConfig {
            actor: PcActorConfig {
                residual: true,
                rezero_init: 0.005,
                hidden_layers: vec![
                    LayerDef {
                        size: 27,
                        activation: Activation::Tanh,
                    },
                    LayerDef {
                        size: 18,
                        activation: Activation::Tanh,
                    },
                ],
                ..default_config().actor
            },
            critic: MlpCriticConfig {
                input_size: 54,
                ..default_config().critic
            },
            ..default_config()
        };
        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
        // Train to modify projection weights
        let input = vec![0.5; 9];
        let valid: Vec<usize> = (0..9).collect();
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        let trajectory = vec![crate::pc_actor_critic::TrajectoryStep {
            input: input.clone(),
            latent_concat: infer.latent_concat,
            y_conv: infer.y_conv,
            hidden_states: infer.hidden_states,
            prediction_errors: infer.prediction_errors,
            tanh_components: infer.tanh_components,
            action,
            valid_actions: valid,
            reward: 1.0,
            surprise_score: infer.surprise_score,
            steps_used: infer.steps_used,
        }];
        agent.learn(&trajectory);

        // Verify projection exists (27→18 requires projection)
        assert!(agent.actor.skip_projections[0].is_some());
        let orig_proj = agent.actor.skip_projections[0].as_ref().unwrap();
        let orig_data = orig_proj.data.clone();

        let path = temp_path("test_skip_proj_roundtrip.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();

        let loaded_proj = loaded.actor.skip_projections[0].as_ref().unwrap();
        assert_eq!(orig_data.len(), loaded_proj.data.len());
        for (i, (a, b)) in orig_data.iter().zip(loaded_proj.data.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-15,
                "skip_projection element {i} differs: {a} vs {b}"
            );
        }
        let _ = fs::remove_file(&path);
    }
}