1use std::path::{Path, PathBuf};
14
15use chrono::Utc;
16use serde::{Deserialize, Serialize};
17
18use crate::error::PcError;
19use crate::layer::Layer;
20use crate::linalg::LinAlg;
21use crate::mlp_critic::MlpCritic;
22use crate::pc_actor::PcActor;
23use crate::pc_actor_critic::{PcActorCritic, PcActorCriticConfig};
24
/// Metadata stored alongside the weights in every save file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMetadata {
    /// Crate version that produced the file (from `CARGO_PKG_VERSION`).
    pub version: String,
    /// RFC 3339 UTC timestamp of when the file was written.
    pub created: String,
    /// Training episode number at which the agent was saved.
    pub episode: usize,
    /// Optional snapshot of training metrics at save time.
    pub metrics: Option<TrainingMetrics>,
}
40
/// Snapshot of training statistics recorded in a checkpoint's metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    /// Fraction of recent episodes won.
    pub win_rate: f64,
    /// Fraction of recent episodes lost.
    pub loss_rate: f64,
    /// Fraction of recent episodes drawn.
    pub draw_rate: f64,
    /// Average surprise score over recent episodes.
    pub avg_surprise: f64,
    /// Current curriculum depth the agent was training at.
    pub curriculum_depth: usize,
}
55
/// Serializable weights of the predictive-coding actor.
///
/// `rezero_alpha` and `skip_projections` use `#[serde(default)]` so that
/// save files written before residual support was added still deserialize
/// (they load as empty, matching a non-residual actor).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PcActorWeights {
    /// One entry per layer (hidden layers plus output layer).
    pub layers: Vec<Layer>,
    /// ReZero residual scaling factors; empty when `residual` is disabled.
    #[serde(default)]
    pub rezero_alpha: Vec<f64>,
    /// Optional projection matrices for skip connections whose dimensions
    /// differ; `None` where an identity skip suffices.
    #[serde(default)]
    pub skip_projections: Vec<Option<crate::matrix::Matrix>>,
}
68
/// Top-level on-disk format: metadata, full config, and both networks' weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SaveFile {
    /// Version/timestamp/episode bookkeeping.
    pub metadata: AgentMetadata,
    /// Complete actor-critic configuration used to rebuild the agent.
    pub config: PcActorCriticConfig,
    /// Actor network weights.
    pub actor_weights: PcActorWeights,
    /// Critic network weights.
    pub critic_weights: crate::mlp_critic::MlpCriticWeights,
}
81
82pub fn save_agent<L: LinAlg>(
100 agent: &PcActorCritic<L>,
101 path: &str,
102 episode: usize,
103 metrics: Option<TrainingMetrics>,
104) -> Result<(), PcError> {
105 let save_file = SaveFile {
106 metadata: AgentMetadata {
107 version: env!("CARGO_PKG_VERSION").to_string(),
108 created: Utc::now().to_rfc3339(),
109 episode,
110 metrics,
111 },
112 config: agent.config.clone(),
113 actor_weights: agent.actor.to_weights(),
114 critic_weights: agent.critic.to_weights(),
115 };
116
117 let json = serde_json::to_string_pretty(&save_file)?;
118
119 let path = Path::new(path);
121 if let Some(parent) = path.parent() {
122 if !parent.as_os_str().is_empty() {
123 std::fs::create_dir_all(parent)?;
124 }
125 }
126
127 std::fs::write(path, json)?;
128 Ok(())
129}
130
131pub fn load_agent(path: &str) -> Result<(PcActorCritic, AgentMetadata), PcError> {
147 let json = std::fs::read_to_string(path)?;
148 let save_file: SaveFile = serde_json::from_str(&json)?;
149
150 let expected_actor_layers = save_file.config.actor.hidden_layers.len() + 1;
152 if save_file.actor_weights.layers.len() != expected_actor_layers {
153 return Err(PcError::DimensionMismatch {
154 expected: expected_actor_layers,
155 got: save_file.actor_weights.layers.len(),
156 context: "actor layer count",
157 });
158 }
159
160 let expected_critic_layers = save_file.config.critic.hidden_layers.len() + 1;
162 if save_file.critic_weights.layers.len() != expected_critic_layers {
163 return Err(PcError::DimensionMismatch {
164 expected: expected_critic_layers,
165 got: save_file.critic_weights.layers.len(),
166 context: "critic layer count",
167 });
168 }
169
170 let actor = PcActor::from_weights(save_file.config.actor.clone(), save_file.actor_weights);
171 let critic = MlpCritic::from_weights(save_file.config.critic.clone(), save_file.critic_weights);
172
173 use rand::SeedableRng;
174 let rng = rand::rngs::StdRng::from_entropy();
175
176 let agent = PcActorCritic::from_parts(save_file.config, actor, critic, rng);
177
178 Ok((agent, save_file.metadata))
179}
180
181pub fn load_agent_generic<L: LinAlg>(
198 path: &str,
199) -> Result<(PcActorCritic<L>, AgentMetadata), PcError> {
200 let json = std::fs::read_to_string(path)?;
201 let save_file: SaveFile = serde_json::from_str(&json)?;
202
203 let expected_actor_layers = save_file.config.actor.hidden_layers.len() + 1;
205 if save_file.actor_weights.layers.len() != expected_actor_layers {
206 return Err(PcError::DimensionMismatch {
207 expected: expected_actor_layers,
208 got: save_file.actor_weights.layers.len(),
209 context: "actor layer count",
210 });
211 }
212
213 let expected_critic_layers = save_file.config.critic.hidden_layers.len() + 1;
215 if save_file.critic_weights.layers.len() != expected_critic_layers {
216 return Err(PcError::DimensionMismatch {
217 expected: expected_critic_layers,
218 got: save_file.critic_weights.layers.len(),
219 context: "critic layer count",
220 });
221 }
222
223 let actor = PcActor::<L>::from_weights(save_file.config.actor.clone(), save_file.actor_weights);
224 let critic =
225 MlpCritic::<L>::from_weights(save_file.config.critic.clone(), save_file.critic_weights);
226
227 use rand::SeedableRng;
228 let rng = rand::rngs::StdRng::from_entropy();
229
230 let agent = PcActorCritic::from_parts(save_file.config, actor, critic, rng);
231
232 Ok((agent, save_file.metadata))
233}
234
235pub fn checkpoint_filename(episode: usize) -> String {
254 let now = Utc::now().format("%Y%m%d_%H%M%S");
255 format!("checkpoint_ep{episode}_{now}.json")
256}
257
258pub fn save_checkpoint<L: LinAlg>(
275 agent: &PcActorCritic<L>,
276 dir: &str,
277 episode: usize,
278 metrics: Option<TrainingMetrics>,
279) -> Result<PathBuf, PcError> {
280 let filename = checkpoint_filename(episode);
281 let path = Path::new(dir).join(filename);
282 let path_str = path.to_string_lossy().to_string();
283 save_agent(agent, &path_str, episode, metrics)?;
284 Ok(path)
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::activation::Activation;
    use crate::layer::LayerDef;
    use crate::mlp_critic::MlpCriticConfig;
    use crate::pc_actor::PcActorConfig;
    use std::fs;

    /// Small 9-in / 9-out tic-tac-toe-sized config used by most tests.
    fn default_config() -> PcActorCriticConfig {
        PcActorCriticConfig {
            actor: PcActorConfig {
                input_size: 9,
                hidden_layers: vec![LayerDef {
                    size: 18,
                    activation: Activation::Tanh,
                }],
                output_size: 9,
                output_activation: Activation::Tanh,
                alpha: 0.1,
                tol: 0.01,
                min_steps: 1,
                max_steps: 20,
                lr_weights: 0.01,
                synchronous: true,
                temperature: 1.0,
                local_lambda: 1.0,
                residual: false,
                rezero_init: 0.001,
            },
            critic: MlpCriticConfig {
                input_size: 27,
                hidden_layers: vec![LayerDef {
                    size: 36,
                    activation: Activation::Tanh,
                }],
                output_activation: Activation::Linear,
                lr: 0.005,
            },
            gamma: 0.95,
            surprise_low: 0.02,
            surprise_high: 0.15,
            adaptive_surprise: false,
            entropy_coeff: 0.01,
        }
    }

    /// Builds a deterministic agent (fixed seed 42) from the default config.
    fn make_agent() -> PcActorCritic {
        let agent: PcActorCritic = PcActorCritic::new(default_config(), 42).unwrap();
        agent
    }

    /// Path for `name` inside a shared temp directory (created if absent).
    fn temp_path(name: &str) -> String {
        let dir = std::env::temp_dir().join("pc_core_tests");
        fs::create_dir_all(&dir).unwrap();
        dir.join(name).to_string_lossy().to_string()
    }

    /// Asserts element-wise equality of two f64 slices within 1e-15.
    fn assert_vecs_approx_eq(a: &[f64], b: &[f64]) {
        assert_eq!(
            a.len(),
            b.len(),
            "Lengths differ: {} vs {}",
            a.len(),
            b.len()
        );
        for (i, (va, vb)) in a.iter().zip(b.iter()).enumerate() {
            assert!((va - vb).abs() < 1e-15, "Element {i} differs: {va} vs {vb}");
        }
    }

    // Save/load must reproduce actor weights and biases exactly.
    #[test]
    fn test_roundtrip_preserves_actor_weights() {
        let agent = make_agent();
        let path = temp_path("test_actor_roundtrip.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();
        for (orig, loaded_layer) in agent.actor.layers.iter().zip(loaded.actor.layers.iter()) {
            assert_vecs_approx_eq(&orig.weights.data, &loaded_layer.weights.data);
            assert_vecs_approx_eq(&orig.bias, &loaded_layer.bias);
        }
        let _ = fs::remove_file(&path);
    }

    // Save/load must reproduce critic weights and biases exactly.
    #[test]
    fn test_roundtrip_preserves_critic_weights() {
        let agent = make_agent();
        let path = temp_path("test_critic_roundtrip.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();
        for (orig, loaded_layer) in agent.critic.layers.iter().zip(loaded.critic.layers.iter()) {
            assert_vecs_approx_eq(&orig.weights.data, &loaded_layer.weights.data);
            assert_vecs_approx_eq(&orig.bias, &loaded_layer.bias);
        }
        let _ = fs::remove_file(&path);
    }

    // Configuration values must survive the JSON roundtrip.
    #[test]
    fn test_roundtrip_preserves_config() {
        let agent = make_agent();
        let path = temp_path("test_config_roundtrip.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();
        assert_eq!(loaded.config.gamma, agent.config.gamma);
        assert_eq!(
            loaded.config.actor.input_size,
            agent.config.actor.input_size
        );
        assert_eq!(
            loaded.config.critic.input_size,
            agent.config.critic.input_size
        );
        assert_eq!(loaded.config.entropy_coeff, agent.config.entropy_coeff);
        let _ = fs::remove_file(&path);
    }

    // Metadata must carry a non-empty version/timestamp and the episode.
    #[test]
    fn test_metadata_includes_version_and_episode() {
        let agent = make_agent();
        let path = temp_path("test_metadata.json");
        save_agent(&agent, &path, 42, None).unwrap();
        let (_, metadata) = load_agent(&path).unwrap();
        assert!(!metadata.version.is_empty());
        assert_eq!(metadata.episode, 42);
        assert!(!metadata.created.is_empty());
        let _ = fs::remove_file(&path);
    }

    // Colons are invalid in Windows filenames; the timestamp format avoids them.
    #[test]
    fn test_checkpoint_filename_no_colons() {
        let name = checkpoint_filename(100);
        assert!(!name.contains(':'), "Filename contains colons: {name}");
    }

    #[test]
    fn test_checkpoint_filename_contains_episode_number() {
        let name = checkpoint_filename(42);
        assert!(
            name.contains("ep42"),
            "Filename doesn't contain episode number: {name}"
        );
        assert!(name.ends_with(".json"));
    }

    // A missing file should surface as PcError::Io, not a panic.
    #[test]
    fn test_load_nonexistent_returns_error() {
        let result = load_agent("/nonexistent/path/agent.json");
        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(
            matches!(err, PcError::Io(_)),
            "Expected PcError::Io, got: {err}"
        );
    }

    // Malformed JSON should surface as PcError::Serialization.
    #[test]
    fn test_load_invalid_json_returns_error() {
        let path = temp_path("test_invalid.json");
        fs::write(&path, "not valid json {{{").unwrap();
        let result = load_agent(&path);
        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(
            matches!(err, PcError::Serialization(_)),
            "Expected PcError::Serialization, got: {err}"
        );
        let _ = fs::remove_file(&path);
    }

    // Tampering with the saved config's topology must be caught at load time.
    #[test]
    fn test_load_topology_mismatch_returns_error() {
        let agent = make_agent();
        let path = temp_path("test_mismatch.json");
        save_agent(&agent, &path, 0, None).unwrap();

        // Re-read the save, add a hidden layer to the config only (weights
        // untouched), and write it back to create a topology mismatch.
        let json = fs::read_to_string(&path).unwrap();
        let mut save_file: SaveFile = serde_json::from_str(&json).unwrap();
        save_file.config.actor.hidden_layers.push(LayerDef {
            size: 10,
            activation: Activation::Relu,
        });
        let tampered = serde_json::to_string_pretty(&save_file).unwrap();
        fs::write(&path, tampered).unwrap();

        let result = load_agent(&path);
        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(
            matches!(err, PcError::DimensionMismatch { .. }),
            "Expected PcError::DimensionMismatch, got: {err}"
        );
        let _ = fs::remove_file(&path);
    }

    // Loading must seed the RNG from entropy: two loads of the same file
    // should not replay identical exploration sequences.
    #[test]
    fn test_load_agent_uses_entropy_seed_not_fixed() {
        let agent = make_agent();
        let path = temp_path("test_seed_entropy.json");
        save_agent(&agent, &path, 10, None).unwrap();

        let (mut loaded1, _) = load_agent(&path).unwrap();
        let (mut loaded2, _) = load_agent(&path).unwrap();

        let input = vec![0.5; 9];
        let valid: Vec<usize> = (0..9).collect();

        let mut actions1 = Vec::new();
        let mut actions2 = Vec::new();
        for _ in 0..20 {
            let (a1, _) = loaded1.act(&input, &valid, crate::pc_actor::SelectionMode::Training);
            let (a2, _) = loaded2.act(&input, &valid, crate::pc_actor::SelectionMode::Training);
            actions1.push(a1);
            actions2.push(a2);
        }

        assert_ne!(
            actions1, actions2,
            "Two loaded agents should have different exploration due to entropy seeding"
        );
        let _ = fs::remove_file(&path);
    }

    // Deterministic inference must be bit-compatible (to 1e-12) after reload.
    #[test]
    fn test_loaded_agent_produces_identical_inference() {
        let agent = make_agent();
        let path = temp_path("test_identical_infer.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();

        let input = vec![0.5, -0.5, 1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0];
        let orig_result = agent.infer(&input);
        let loaded_result = loaded.infer(&input);

        assert_eq!(orig_result.y_conv.len(), loaded_result.y_conv.len());
        for (a, b) in orig_result.y_conv.iter().zip(loaded_result.y_conv.iter()) {
            assert!((a - b).abs() < 1e-12, "y_conv differs: {a} vs {b}");
        }
        for (a, b) in orig_result
            .latent_concat
            .iter()
            .zip(loaded_result.latent_concat.iter())
        {
            assert!((a - b).abs() < 1e-12, "latent_concat differs: {a} vs {b}");
        }
        let _ = fs::remove_file(&path);
    }

    // save_agent should create missing intermediate directories.
    #[test]
    fn test_save_creates_parent_directories() {
        let dir = std::env::temp_dir()
            .join("pc_core_tests")
            .join("nested")
            .join("deep");
        let path = dir.join("agent.json").to_string_lossy().to_string();

        let _ = fs::remove_dir_all(&dir);

        let agent = make_agent();
        save_agent(&agent, &path, 0, None).unwrap();
        assert!(Path::new(&path).exists());

        let _ = fs::remove_dir_all(std::env::temp_dir().join("pc_core_tests").join("nested"));
    }

    // Trained (non-initial) rezero_alpha values must survive the roundtrip.
    #[test]
    fn test_roundtrip_preserves_modified_rezero_alpha() {
        use crate::pc_actor::SelectionMode;
        // Residual config with two equal-width hidden layers so rezero is active.
        let config = PcActorCriticConfig {
            actor: PcActorConfig {
                residual: true,
                rezero_init: 0.005,
                hidden_layers: vec![
                    LayerDef {
                        size: 27,
                        activation: Activation::Tanh,
                    },
                    LayerDef {
                        size: 27,
                        activation: Activation::Tanh,
                    },
                ],
                ..default_config().actor
            },
            critic: MlpCriticConfig {
                input_size: 63,
                ..default_config().critic
            },
            ..default_config()
        };
        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
        let input = vec![0.5; 9];
        let valid: Vec<usize> = (0..9).collect();
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        let trajectory = vec![crate::pc_actor_critic::TrajectoryStep {
            input: input.clone(),
            latent_concat: infer.latent_concat,
            y_conv: infer.y_conv,
            hidden_states: infer.hidden_states,
            prediction_errors: infer.prediction_errors,
            tanh_components: infer.tanh_components,
            action,
            valid_actions: valid,
            reward: 1.0,
            surprise_score: infer.surprise_score,
            steps_used: infer.steps_used,
        }];
        agent.learn(&trajectory);
        let alpha_after_train = agent.actor.rezero_alpha.clone();
        // One learn step should have moved alpha off its init value.
        assert_ne!(alpha_after_train, vec![0.005]);

        let path = temp_path("test_rezero_roundtrip.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();
        assert_eq!(
            loaded.actor.rezero_alpha, alpha_after_train,
            "Loaded rezero_alpha should match trained value, not rezero_init"
        );
        let _ = fs::remove_file(&path);
    }

    // Non-residual agents have empty rezero_alpha; serde(default) keeps old
    // files loadable and the roundtrip must preserve emptiness.
    #[test]
    fn test_roundtrip_non_residual_backward_compat() {
        let agent = make_agent();
        assert!(agent.actor.rezero_alpha.is_empty());

        let path = temp_path("test_nonresidual_compat.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();
        assert!(loaded.actor.rezero_alpha.is_empty());
        let _ = fs::remove_file(&path);
    }

    // The generic loader must produce the same inference as the default one.
    #[test]
    fn test_load_agent_generic_matches_load_agent() {
        let agent = make_agent();
        let path = temp_path("test_generic_load.json");
        save_agent(&agent, &path, 10, None).unwrap();

        let (loaded_default, _) = load_agent(&path).unwrap();
        let (loaded_generic, _) =
            load_agent_generic::<crate::linalg::cpu::CpuLinAlg>(&path).unwrap();

        let input = vec![0.5, -0.5, 1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0];
        let r1 = loaded_default.infer(&input);
        let r2 = loaded_generic.infer(&input);

        for (a, b) in r1.y_conv.iter().zip(r2.y_conv.iter()) {
            assert!((a - b).abs() < 1e-15, "y_conv differs: {a} vs {b}");
        }
        let _ = fs::remove_file(&path);
    }

    // Mismatched hidden sizes (27 -> 18) force a skip projection matrix; its
    // contents must survive the roundtrip element-for-element.
    #[test]
    fn test_roundtrip_preserves_skip_projections_directly() {
        use crate::pc_actor::SelectionMode;
        let config = PcActorCriticConfig {
            actor: PcActorConfig {
                residual: true,
                rezero_init: 0.005,
                hidden_layers: vec![
                    LayerDef {
                        size: 27,
                        activation: Activation::Tanh,
                    },
                    LayerDef {
                        size: 18,
                        activation: Activation::Tanh,
                    },
                ],
                ..default_config().actor
            },
            critic: MlpCriticConfig {
                input_size: 54,
                ..default_config().critic
            },
            ..default_config()
        };
        let mut agent: PcActorCritic = PcActorCritic::new(config, 42).unwrap();
        let input = vec![0.5; 9];
        let valid: Vec<usize> = (0..9).collect();
        let (action, infer) = agent.act(&input, &valid, SelectionMode::Training);
        let trajectory = vec![crate::pc_actor_critic::TrajectoryStep {
            input: input.clone(),
            latent_concat: infer.latent_concat,
            y_conv: infer.y_conv,
            hidden_states: infer.hidden_states,
            prediction_errors: infer.prediction_errors,
            tanh_components: infer.tanh_components,
            action,
            valid_actions: valid,
            reward: 1.0,
            surprise_score: infer.surprise_score,
            steps_used: infer.steps_used,
        }];
        agent.learn(&trajectory);

        assert!(agent.actor.skip_projections[0].is_some());
        let orig_proj = agent.actor.skip_projections[0].as_ref().unwrap();
        let orig_data = orig_proj.data.clone();

        let path = temp_path("test_skip_proj_roundtrip.json");
        save_agent(&agent, &path, 10, None).unwrap();
        let (loaded, _) = load_agent(&path).unwrap();

        let loaded_proj = loaded.actor.skip_projections[0].as_ref().unwrap();
        assert_eq!(orig_data.len(), loaded_proj.data.len());
        for (i, (a, b)) in orig_data.iter().zip(loaded_proj.data.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-15,
                "skip_projection element {i} differs: {a} vs {b}"
            );
        }
        let _ = fs::remove_file(&path);
    }
}