1use std::path::{Path, PathBuf};
14
15use chrono::Utc;
16use serde::{Deserialize, Serialize};
17
18use crate::error::PcError;
19use crate::layer::Layer;
20use crate::linalg::LinAlg;
21use crate::mlp_critic::MlpCritic;
22use crate::pc_actor::PcActor;
23use crate::pc_actor_critic::{PcActorCritic, PcActorCriticConfig};
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct AgentMetadata {
31 pub version: String,
33 pub created: String,
35 pub episode: usize,
37 pub metrics: Option<TrainingMetrics>,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct TrainingMetrics {
44 pub win_rate: f64,
46 pub loss_rate: f64,
48 pub draw_rate: f64,
50 pub avg_surprise: f64,
52 pub curriculum_depth: usize,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct PcActorWeights {
59 pub layers: Vec<Layer>,
61 #[serde(default)]
63 pub rezero_alpha: Vec<f64>,
64 #[serde(default)]
66 pub skip_projections: Vec<Option<crate::matrix::Matrix>>,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct SaveFile {
72 pub metadata: AgentMetadata,
74 pub config: PcActorCriticConfig,
76 pub actor_weights: PcActorWeights,
78 pub critic_weights: crate::mlp_critic::MlpCriticWeights,
80}
81
82pub fn save_agent<L: LinAlg>(
100 agent: &PcActorCritic<L>,
101 path: &str,
102 episode: usize,
103 metrics: Option<TrainingMetrics>,
104) -> Result<(), PcError> {
105 let save_file = SaveFile {
106 metadata: AgentMetadata {
107 version: env!("CARGO_PKG_VERSION").to_string(),
108 created: Utc::now().to_rfc3339(),
109 episode,
110 metrics,
111 },
112 config: agent.config.clone(),
113 actor_weights: agent.actor.to_weights(),
114 critic_weights: agent.critic.to_weights(),
115 };
116
117 let json = serde_json::to_string_pretty(&save_file)?;
118
119 let path = Path::new(path);
121 if let Some(parent) = path.parent() {
122 if !parent.as_os_str().is_empty() {
123 std::fs::create_dir_all(parent)?;
124 }
125 }
126
127 std::fs::write(path, json)?;
128 Ok(())
129}
130
131pub fn load_agent(
147 path: &str,
148 backend: crate::linalg::cpu::CpuLinAlg,
149) -> Result<(PcActorCritic, AgentMetadata), PcError> {
150 load_agent_generic(path, backend)
151}
152
153pub fn load_agent_generic<L: LinAlg>(
170 path: &str,
171 backend: L,
172) -> Result<(PcActorCritic<L>, AgentMetadata), PcError> {
173 let json = std::fs::read_to_string(path)?;
174 let save_file: SaveFile = serde_json::from_str(&json)?;
175
176 let actor = PcActor::<L>::from_weights(
177 backend.clone(),
178 save_file.config.actor.clone(),
179 save_file.actor_weights,
180 )?;
181 let critic = MlpCritic::<L>::from_weights(
182 backend.clone(),
183 save_file.config.critic.clone(),
184 save_file.critic_weights,
185 )?;
186
187 use rand::SeedableRng;
188 let rng = rand::rngs::StdRng::from_entropy();
189
190 let agent = PcActorCritic::from_parts(save_file.config, actor, critic, rng, backend);
191
192 Ok((agent, save_file.metadata))
193}
194
195pub fn checkpoint_filename(episode: usize) -> String {
214 let now = Utc::now().format("%Y%m%d_%H%M%S");
215 format!("checkpoint_ep{episode}_{now}.json")
216}
217
218pub fn save_checkpoint<L: LinAlg>(
235 agent: &PcActorCritic<L>,
236 dir: &str,
237 episode: usize,
238 metrics: Option<TrainingMetrics>,
239) -> Result<PathBuf, PcError> {
240 let filename = checkpoint_filename(episode);
241 let path = Path::new(dir).join(filename);
242 let path_str = path.to_string_lossy().to_string();
243 save_agent(agent, &path_str, episode, metrics)?;
244 Ok(path)
245}
246
247#[cfg(test)]
248mod tests {
249 use super::*;
250 use crate::activation::Activation;
251 use crate::layer::LayerDef;
252 use crate::mlp_critic::MlpCriticConfig;
253 use crate::pc_actor::PcActorConfig;
254 use std::fs;
255
256 fn default_config() -> PcActorCriticConfig {
257 PcActorCriticConfig {
258 actor: PcActorConfig {
259 input_size: 9,
260 hidden_layers: vec![LayerDef {
261 size: 18,
262 activation: Activation::Tanh,
263 }],
264 output_size: 9,
265 output_activation: Activation::Tanh,
266 alpha: 0.1,
267 tol: 0.01,
268 min_steps: 1,
269 max_steps: 20,
270 lr_weights: 0.01,
271 synchronous: true,
272 temperature: 1.0,
273 local_lambda: 1.0,
274 residual: false,
275 rezero_init: 0.001,
276 },
277 critic: MlpCriticConfig {
278 input_size: 27,
279 hidden_layers: vec![LayerDef {
280 size: 36,
281 activation: Activation::Tanh,
282 }],
283 output_activation: Activation::Linear,
284 lr: 0.005,
285 },
286 gamma: 0.95,
287 surprise_low: 0.02,
288 surprise_high: 0.15,
289 adaptive_surprise: false,
290 surprise_buffer_size: 100,
291 entropy_coeff: 0.01,
292 }
293 }
294
295 fn make_agent() -> PcActorCritic {
296 use crate::linalg::cpu::CpuLinAlg;
297 let agent: PcActorCritic =
298 PcActorCritic::new(CpuLinAlg::new(), default_config(), 42).unwrap();
299 agent
300 }
301
/// Returns `<system temp dir>/pc_core_tests/<name>` as a string, creating the
/// `pc_core_tests` directory if it does not exist yet.
fn temp_path(name: &str) -> String {
    let mut target = std::env::temp_dir();
    target.push("pc_core_tests");
    fs::create_dir_all(&target).unwrap();
    target.push(name);
    target.to_string_lossy().into_owned()
}
307
/// Panics unless `a` and `b` have the same length and every pair of
/// corresponding elements differs by less than `1e-15` in absolute value.
fn assert_vecs_approx_eq(a: &[f64], b: &[f64]) {
    let (len_a, len_b) = (a.len(), b.len());
    assert_eq!(len_a, len_b, "Lengths differ: {} vs {}", len_a, len_b);

    a.iter().zip(b).enumerate().for_each(|(i, (va, vb))| {
        let gap = (va - vb).abs();
        assert!(gap < 1e-15, "Element {i} differs: {va} vs {vb}");
    });
}
321
/// Saving and then loading must reproduce every actor layer's weights and
/// biases.
#[test]
fn test_roundtrip_preserves_actor_weights() {
    use crate::linalg::cpu::CpuLinAlg;

    let agent = make_agent();
    let path = temp_path("test_actor_roundtrip.json");
    save_agent(&agent, &path, 10, None).unwrap();
    let (loaded, _) = load_agent(&path, CpuLinAlg::new()).unwrap();
    // The file is no longer needed once loaded; removing it now means a
    // failing assertion below cannot leave it behind.
    let _ = fs::remove_file(&path);

    // `zip` stops at the shorter side, so without this check a loaded agent
    // with fewer (or zero) layers would pass the loop vacuously.
    assert_eq!(
        agent.actor.layers.len(),
        loaded.actor.layers.len(),
        "actor layer count changed across the round trip"
    );
    for (orig, loaded_layer) in agent.actor.layers.iter().zip(loaded.actor.layers.iter()) {
        assert_vecs_approx_eq(&orig.weights.data, &loaded_layer.weights.data);
        assert_vecs_approx_eq(&orig.bias, &loaded_layer.bias);
    }
}
334
/// Saving and then loading must reproduce every critic layer's weights and
/// biases.
#[test]
fn test_roundtrip_preserves_critic_weights() {
    use crate::linalg::cpu::CpuLinAlg;

    let agent = make_agent();
    let path = temp_path("test_critic_roundtrip.json");
    save_agent(&agent, &path, 10, None).unwrap();
    let (loaded, _) = load_agent(&path, CpuLinAlg::new()).unwrap();
    // The file is no longer needed once loaded; removing it now means a
    // failing assertion below cannot leave it behind.
    let _ = fs::remove_file(&path);

    // `zip` stops at the shorter side, so without this check a loaded agent
    // with fewer (or zero) layers would pass the loop vacuously.
    assert_eq!(
        agent.critic.layers.len(),
        loaded.critic.layers.len(),
        "critic layer count changed across the round trip"
    );
    for (orig, loaded_layer) in agent.critic.layers.iter().zip(loaded.critic.layers.iter()) {
        assert_vecs_approx_eq(&orig.weights.data, &loaded_layer.weights.data);
        assert_vecs_approx_eq(&orig.bias, &loaded_layer.bias);
    }
}
347
/// A sample of config fields (gamma, both input sizes, entropy coefficient)
/// must come back unchanged after a save/load cycle.
#[test]
fn test_roundtrip_preserves_config() {
    use crate::linalg::cpu::CpuLinAlg;

    let original = make_agent();
    let path = temp_path("test_config_roundtrip.json");
    save_agent(&original, &path, 10, None).unwrap();
    let (restored, _) = load_agent(&path, CpuLinAlg::new()).unwrap();

    let (before, after) = (&original.config, &restored.config);
    assert_eq!(after.gamma, before.gamma);
    assert_eq!(after.actor.input_size, before.actor.input_size);
    assert_eq!(after.critic.input_size, before.critic.input_size);
    assert_eq!(after.entropy_coeff, before.entropy_coeff);

    let _ = fs::remove_file(&path);
}
366
/// The stored metadata must carry a non-empty version, a non-empty creation
/// timestamp, and the episode number passed to `save_agent`.
#[test]
fn test_metadata_includes_version_and_episode() {
    use crate::linalg::cpu::CpuLinAlg;

    let path = temp_path("test_metadata.json");
    save_agent(&make_agent(), &path, 42, None).unwrap();

    let (_, meta) = load_agent(&path, CpuLinAlg::new()).unwrap();
    assert!(!meta.version.is_empty());
    assert_eq!(meta.episode, 42);
    assert!(!meta.created.is_empty());

    let _ = fs::remove_file(&path);
}
378
/// Checkpoint names must not contain `:`.
#[test]
fn test_checkpoint_filename_no_colons() {
    let name = checkpoint_filename(100);
    let has_colon = name.contains(':');
    assert!(!has_colon, "Filename contains colons: {name}");
}
384
/// Checkpoint names must embed `ep<episode>` and end in `.json`.
#[test]
fn test_checkpoint_filename_contains_episode_number() {
    let name = checkpoint_filename(42);
    let mentions_episode = name.contains("ep42");
    assert!(
        mentions_episode,
        "Filename doesn't contain episode number: {name}"
    );
    assert!(name.ends_with(".json"));
}
394
/// Loading from a path that does not exist must fail with `PcError::Io`.
#[test]
fn test_load_nonexistent_returns_error() {
    use crate::linalg::cpu::CpuLinAlg;

    let missing = "/nonexistent/path/agent.json";
    let outcome = load_agent(missing, CpuLinAlg::new());
    assert!(outcome.is_err());

    let err = outcome.err().unwrap();
    let is_io = matches!(err, PcError::Io(_));
    assert!(is_io, "Expected PcError::Io, got: {err}");
}
408
/// Loading a file whose contents are not valid JSON must fail with
/// `PcError::Serialization`.
#[test]
fn test_load_invalid_json_returns_error() {
    use crate::linalg::cpu::CpuLinAlg;

    let path = temp_path("test_invalid.json");
    fs::write(&path, "not valid json {{{").unwrap();

    let outcome = load_agent(&path, CpuLinAlg::new());
    assert!(outcome.is_err());

    let err = outcome.err().unwrap();
    let is_serialization = matches!(err, PcError::Serialization(_));
    assert!(
        is_serialization,
        "Expected PcError::Serialization, got: {err}"
    );

    let _ = fs::remove_file(&path);
}
422
/// A save file whose config declares an extra actor hidden layer that the
/// stored weights do not contain must be rejected with
/// `PcError::DimensionMismatch`.
#[test]
fn test_load_topology_mismatch_returns_error() {
    use crate::linalg::cpu::CpuLinAlg;

    let path = temp_path("test_mismatch.json");
    save_agent(&make_agent(), &path, 0, None).unwrap();

    // Re-read the file, add a hidden layer to the config only, and write it
    // back so config and weights disagree.
    let stored = fs::read_to_string(&path).unwrap();
    let mut doc = serde_json::from_str::<SaveFile>(&stored).unwrap();
    let extra_layer = LayerDef {
        size: 10,
        activation: Activation::Relu,
    };
    doc.config.actor.hidden_layers.push(extra_layer);
    fs::write(&path, serde_json::to_string_pretty(&doc).unwrap()).unwrap();

    let outcome = load_agent(&path, CpuLinAlg::new());
    assert!(outcome.is_err());

    let err = outcome.err().unwrap();
    let is_mismatch = matches!(err, PcError::DimensionMismatch { .. });
    assert!(
        is_mismatch,
        "Expected PcError::DimensionMismatch, got: {err}"
    );

    let _ = fs::remove_file(&path);
}
449
/// Two agents loaded from the same file must not produce the same sequence of
/// 20 training-mode actions, because each load seeds its RNG from entropy.
///
/// NOTE(review): this is a probabilistic check — identical sequences are
/// possible in principle, and more likely if the policy is sharply peaked.
#[test]
fn test_load_agent_uses_entropy_seed_not_fixed() {
    use crate::linalg::cpu::CpuLinAlg;
    use crate::pc_actor::SelectionMode;

    let path = temp_path("test_seed_entropy.json");
    save_agent(&make_agent(), &path, 10, None).unwrap();

    let (mut loaded1, _) = load_agent(&path, CpuLinAlg::new()).unwrap();
    let (mut loaded2, _) = load_agent(&path, CpuLinAlg::new()).unwrap();

    let input = vec![0.5; 9];
    let valid: Vec<usize> = (0..9).collect();

    // Query both agents in lock-step and split the picks into two sequences.
    let (actions1, actions2): (Vec<_>, Vec<_>) = (0..20)
        .map(|_| {
            let (a1, _) = loaded1.act(&input, &valid, SelectionMode::Training);
            let (a2, _) = loaded2.act(&input, &valid, SelectionMode::Training);
            (a1, a2)
        })
        .unzip();

    assert_ne!(
        actions1, actions2,
        "Two loaded agents should have different exploration due to entropy seeding"
    );
    let _ = fs::remove_file(&path);
}
479
/// A loaded agent must give the same `infer` output (`y_conv` and
/// `latent_concat`, within `1e-12`) as the agent it was saved from.
#[test]
fn test_loaded_agent_produces_identical_inference() {
    use crate::linalg::cpu::CpuLinAlg;

    let agent = make_agent();
    let path = temp_path("test_identical_infer.json");
    save_agent(&agent, &path, 10, None).unwrap();
    let (loaded, _) = load_agent(&path, CpuLinAlg::new()).unwrap();
    // The file is no longer needed once loaded; removing it now means a
    // failing assertion below cannot leave it behind.
    let _ = fs::remove_file(&path);

    let input = vec![0.5, -0.5, 1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0];
    let orig_result = agent.infer(&input);
    let loaded_result = loaded.infer(&input);

    assert_eq!(orig_result.y_conv.len(), loaded_result.y_conv.len());
    for (a, b) in orig_result.y_conv.iter().zip(loaded_result.y_conv.iter()) {
        assert!((a - b).abs() < 1e-12, "y_conv differs: {a} vs {b}");
    }

    // `zip` truncates to the shorter side, so check the length here as well;
    // previously only `y_conv` had this guard.
    assert_eq!(
        orig_result.latent_concat.len(),
        loaded_result.latent_concat.len()
    );
    for (a, b) in orig_result
        .latent_concat
        .iter()
        .zip(loaded_result.latent_concat.iter())
    {
        assert!((a - b).abs() < 1e-12, "latent_concat differs: {a} vs {b}");
    }
}
506
/// `save_agent` must create missing parent directories
/// (`pc_core_tests/nested/deep`) before writing the file.
#[test]
fn test_save_creates_parent_directories() {
    let scratch = std::env::temp_dir().join("pc_core_tests").join("nested");
    let dir = scratch.join("deep");
    let path = dir.join("agent.json").to_string_lossy().into_owned();

    // Start from a state where the deepest directory does not exist.
    let _ = fs::remove_dir_all(&dir);

    let agent = make_agent();
    save_agent(&agent, &path, 0, None).unwrap();
    assert!(Path::new(&path).exists());

    let _ = fs::remove_dir_all(&scratch);
}
525
/// After one learning step on a residual actor, the trained `rezero_alpha`
/// must survive a save/load cycle rather than being reset to `rezero_init`.
#[test]
fn test_roundtrip_preserves_modified_rezero_alpha() {
    use crate::linalg::cpu::CpuLinAlg;
    use crate::pc_actor::SelectionMode;
    use crate::pc_actor_critic::TrajectoryStep;

    // Residual actor with two 27-unit tanh hidden layers; critic input is 63.
    let mut config = default_config();
    config.actor.residual = true;
    config.actor.rezero_init = 0.005;
    config.actor.hidden_layers = vec![
        LayerDef {
            size: 27,
            activation: Activation::Tanh,
        },
        LayerDef {
            size: 27,
            activation: Activation::Tanh,
        },
    ];
    config.critic.input_size = 63;

    let mut agent: PcActorCritic = PcActorCritic::new(CpuLinAlg::new(), config, 42).unwrap();

    // One action followed by one learning update on a single-step trajectory.
    let input = vec![0.5; 9];
    let legal: Vec<usize> = (0..9).collect();
    let (chosen, inference) = agent.act(&input, &legal, SelectionMode::Training);
    let step = TrajectoryStep {
        input: input.clone(),
        latent_concat: inference.latent_concat,
        y_conv: inference.y_conv,
        hidden_states: inference.hidden_states,
        prediction_errors: inference.prediction_errors,
        tanh_components: inference.tanh_components,
        action: chosen,
        valid_actions: legal,
        reward: 1.0,
        surprise_score: inference.surprise_score,
        steps_used: inference.steps_used,
    };
    agent.learn(&[step]);

    let alpha_after_train = agent.actor.rezero_alpha.clone();
    // NOTE(review): if `rezero_alpha` holds more than one entry this
    // inequality is satisfied by length alone — confirm the expected length.
    assert_ne!(alpha_after_train, vec![0.005]);

    let path = temp_path("test_rezero_roundtrip.json");
    save_agent(&agent, &path, 10, None).unwrap();
    let (loaded, _) = load_agent(&path, CpuLinAlg::new()).unwrap();
    assert_eq!(
        loaded.actor.rezero_alpha, alpha_after_train,
        "Loaded rezero_alpha should match trained value, not rezero_init"
    );
    let _ = fs::remove_file(&path);
}
584
/// A non-residual agent has an empty `rezero_alpha`, and it must still be
/// empty after a save/load cycle.
#[test]
fn test_roundtrip_non_residual_backward_compat() {
    use crate::linalg::cpu::CpuLinAlg;

    let original = make_agent();
    assert!(original.actor.rezero_alpha.is_empty());

    let path = temp_path("test_nonresidual_compat.json");
    save_agent(&original, &path, 10, None).unwrap();

    let (restored, _) = load_agent(&path, CpuLinAlg::new()).unwrap();
    assert!(restored.actor.rezero_alpha.is_empty());

    let _ = fs::remove_file(&path);
}
596
/// `load_agent` and `load_agent_generic::<CpuLinAlg>` must yield agents whose
/// `infer` output `y_conv` agrees within `1e-15`.
#[test]
fn test_load_agent_generic_matches_load_agent() {
    use crate::linalg::cpu::CpuLinAlg;

    let agent = make_agent();
    let path = temp_path("test_generic_load.json");
    save_agent(&agent, &path, 10, None).unwrap();

    let (loaded_default, _) = load_agent(&path, CpuLinAlg::new()).unwrap();
    let (loaded_generic, _) =
        load_agent_generic::<CpuLinAlg>(&path, CpuLinAlg::new()).unwrap();
    // Both loads are done; removing the file now means a failing assertion
    // below cannot leave it behind.
    let _ = fs::remove_file(&path);

    let input = vec![0.5, -0.5, 1.0, -1.0, 0.0, 0.5, -0.5, 1.0, -1.0];
    let r1 = loaded_default.infer(&input);
    let r2 = loaded_generic.infer(&input);

    // `zip` truncates to the shorter side; without this check an empty or
    // shorter output on one side would pass the loop vacuously.
    assert_eq!(r1.y_conv.len(), r2.y_conv.len());
    for (a, b) in r1.y_conv.iter().zip(r2.y_conv.iter()) {
        assert!((a - b).abs() < 1e-15, "y_conv differs: {a} vs {b}");
    }
}
619
/// With hidden layers of different sizes (27 then 18) on a residual actor, the
/// first skip projection must be present and must survive a save/load cycle
/// element for element (within `1e-15`).
#[test]
fn test_roundtrip_preserves_skip_projections_directly() {
    use crate::linalg::cpu::CpuLinAlg;
    use crate::pc_actor::SelectionMode;
    use crate::pc_actor_critic::TrajectoryStep;

    // Residual actor with a 27-unit then an 18-unit tanh hidden layer; critic
    // input is 54.
    let mut config = default_config();
    config.actor.residual = true;
    config.actor.rezero_init = 0.005;
    config.actor.hidden_layers = vec![
        LayerDef {
            size: 27,
            activation: Activation::Tanh,
        },
        LayerDef {
            size: 18,
            activation: Activation::Tanh,
        },
    ];
    config.critic.input_size = 54;

    let mut agent: PcActorCritic = PcActorCritic::new(CpuLinAlg::new(), config, 42).unwrap();

    // One action followed by one learning update on a single-step trajectory.
    let input = vec![0.5; 9];
    let legal: Vec<usize> = (0..9).collect();
    let (chosen, inference) = agent.act(&input, &legal, SelectionMode::Training);
    let step = TrajectoryStep {
        input: input.clone(),
        latent_concat: inference.latent_concat,
        y_conv: inference.y_conv,
        hidden_states: inference.hidden_states,
        prediction_errors: inference.prediction_errors,
        tanh_components: inference.tanh_components,
        action: chosen,
        valid_actions: legal,
        reward: 1.0,
        surprise_score: inference.surprise_score,
        steps_used: inference.steps_used,
    };
    agent.learn(&[step]);

    assert!(agent.actor.skip_projections[0].is_some());
    let orig_data = agent.actor.skip_projections[0]
        .as_ref()
        .unwrap()
        .data
        .clone();

    let path = temp_path("test_skip_proj_roundtrip.json");
    save_agent(&agent, &path, 10, None).unwrap();
    let (loaded, _) = load_agent(&path, CpuLinAlg::new()).unwrap();

    let loaded_proj = loaded.actor.skip_projections[0].as_ref().unwrap();
    assert_eq!(orig_data.len(), loaded_proj.data.len());
    orig_data
        .iter()
        .zip(loaded_proj.data.iter())
        .enumerate()
        .for_each(|(i, (a, b))| {
            assert!(
                (a - b).abs() < 1e-15,
                "skip_projection element {i} differs: {a} vs {b}"
            );
        });
    let _ = fs::remove_file(&path);
}
685
/// The checked-in fixture `tests/fixtures/v1_model.json` must still load, carry
/// a non-empty version and episode 100, and yield an agent that picks an
/// in-range action in play mode.
///
/// The fixture path is relative, so this relies on the working directory the
/// test runner uses.
#[test]
fn test_v1_fixture_loads_in_v2() {
    use crate::linalg::cpu::CpuLinAlg;
    use crate::pc_actor::SelectionMode;

    let fixture = "tests/fixtures/v1_model.json";
    let (mut agent, metadata) = load_agent(fixture, CpuLinAlg::new()).unwrap();

    assert!(!metadata.version.is_empty());
    assert_eq!(metadata.episode, 100);

    let state = vec![0.5; 9];
    let legal: Vec<usize> = (0..9).collect();
    let (action, _) = agent.act(&state, &legal, SelectionMode::Play);
    assert!(action < 9, "Action must be in valid range");
}
700}