1use crate::{
6 Config, ArticleExtractionEnvironment, BaselineExtractor, ExtractionError,
7 agents::{AgentFactory, RLAgent},
8};
9
10use crate::{
11 replay_buffer::PrioritizedReplayBuffer,
12 SiteProfileMemory,
13 reward::ImprovedRewardCalculator,
14 curriculum::CurriculumManager,
15 Result,
16};
17
18use crate::environment::StepInfo;
19use indicatif::{ProgressBar, ProgressStyle};
20use std::path::Path;
21use tracing::{info, warn};
22use crate::{Checkpoint, CheckpointManager};
23use candle_nn::{VarMap};
24use candle_core::Device;
25
26
27fn extract_domain_from_url(url: &str) -> String {
30 use url::Url;
31
32 match Url::parse(url) {
34 Ok(parsed_url) => {
35 parsed_url.host_str()
36 .map(|h| h.to_string())
37 .unwrap_or_else(|| "unknown".to_string())
38 }
39 Err(_) => {
40 let url = url.trim();
42 let without_protocol = url.strip_prefix("https://")
43 .or_else(|| url.strip_prefix("http://"))
44 .unwrap_or(url);
45
46 let host_part = without_protocol.split('/').next().unwrap_or("");
48
49 let domain = host_part.split(':').next().unwrap_or("");
51
52 if domain.is_empty() {
53 "unknown".to_string()
54 } else {
55 domain.to_string()
56 }
57 }
58 }
59}
60
61
/// Aggregated statistics for one training run, accumulated episode by episode.
#[derive(Debug, Clone, Default)] pub struct TrainingMetrics {
    // Total reward collected per episode, in episode order.
    pub episode_rewards: Vec<f32>,
    // Final extraction quality score per episode, in episode order.
    pub episode_qualities: Vec<f32>,
    // Loss value from each executed training/gradient step; length is
    // unrelated to the number of episodes (depends on train frequency).
    pub episode_losses: Vec<f32>,
    // Best rolling-average quality seen so far; persisted in checkpoints and
    // used to gate "best model" saves.
    pub best_avg_quality: f32,
}
70
71
72pub fn train_standard(
74 config: &Config,
75 html_samples: Vec<(String, String)>,
76) -> Result<(Box<dyn RLAgent>, TrainingMetrics)> {
77 info!("Starting standard training for {} episodes", config.num_episodes);
78
79 let device = if config.use_cpu_for_tuning {
80 Device::Cpu
81 } else if crate::cuda_is_available() {
82 Device::cuda_if_available(0).unwrap_or(Device::Cpu)
83 } else {
84 Device::Cpu
85 };
86
87 let _varmap = VarMap::new();
88
89 let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
91 let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
92 let mut replay_buffer = PrioritizedReplayBuffer::new(
93 config.replay_buffer_size,
94 config.priority_alpha,
95 config.priority_beta,
96 );
97
98 let mut agent = AgentFactory::create(
99 config.algorithm,
100 config.state_dim,
101 config.num_discrete_actions,
102 config.num_continuous_params,
103 config.gamma as f32,
104 config.learning_rate,
105 &device,
106 )?;
107
108 let mut env = ArticleExtractionEnvironment::new(baseline_extractor, config.clone());
117 let mut metrics = TrainingMetrics::default();
118 let mut epsilon = config.epsilon_start;
119
120 let checkpoint_dir = config.models_dir.join("checkpoints");
122 let checkpoint_manager = CheckpointManager::new(checkpoint_dir, 5)?;
123
124 let start_episode = match checkpoint_manager.load_latest() {
126 Ok(Some(checkpoint)) => {
127 if checkpoint.episode >= config.num_episodes {
129 warn!(
130 "Found checkpoint at episode {} but current run is only {} episodes. Starting fresh.",
131 checkpoint.episode, config.num_episodes
132 );
133 0
134 } else if !checkpoint.model_path.exists() {
135 warn!(
136 "Checkpoint references missing model file: {}. Starting fresh.",
137 checkpoint.model_path.display()
138 );
139 0
140 } else {
141 info!("Found checkpoint at episode {}, attempting to load...", checkpoint.episode);
143
144 match AgentFactory::load(
145 &checkpoint.model_path,
146 config.state_dim,
147 config.num_discrete_actions,
148 config.num_continuous_params,
149 &device,
150 ) {
151 Ok(loaded_agent) => {
152 agent = loaded_agent;
153 epsilon = checkpoint.epsilon as f64;
154 metrics.best_avg_quality = checkpoint.best_quality;
155 info!("Successfully resumed from checkpoint at episode {}", checkpoint.episode);
156 checkpoint.episode
157 }
158 Err(e) => {
159 warn!("Failed to load checkpoint model: {}. Starting fresh.", e);
160 warn!("Consider deleting checkpoint directory if corruption persists.");
161 0
162 }
163 }
164 }
165 }
166 Ok(None) => {
167 info!("No checkpoint found, starting fresh training");
168 0
169 }
170 Err(e) => {
171 warn!("Error loading checkpoint: {}. Starting fresh.", e);
172 0
173 }
174 };
175
176 let pb = ProgressBar::new((config.num_episodes - start_episode) as u64);
178 pb.set_style(
179 ProgressStyle::default_bar()
180 .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
181 .unwrap()
182 .progress_chars("=>-"),
183 );
184
185 for episode in start_episode..config.num_episodes {
186 let idx = episode % html_samples.len();
188 let (html, url) = &html_samples[idx];
189
190 let domain = extract_domain_from_url(url);
191
192 let site_profile = site_memory.get_profile(&domain);
193
194 let mut state = env.reset(html, url.clone(), Some(site_profile))?;
196
197 let mut episode_reward = 0.0;
198 let mut done = false;
199 let mut step_info = StepInfo {
200 quality_score: 0.0,
201 text: String::new(),
202 xpath: String::new(),
203 parameters: std::collections::HashMap::new(),
204 step_count: 0,
205 };
206
207 while !done {
209 let action = agent.select_action(&state, epsilon as f32)?;
210 let (next_state, reward, is_done, info) = env.step(action.clone())?;
211
212 episode_reward += reward;
213 done = is_done;
214 step_info = info;
215
216 let experience = crate::replay_buffer::Experience {
218 state: state.clone(),
219 action,
220 reward,
221 next_state: next_state.clone(),
222 done,
223 };
224 replay_buffer.add(experience);
225
226 if replay_buffer.len() > config.batch_size * 10 {
228 let loss = agent.train_step(&mut replay_buffer, config.batch_size)?;
229 metrics.episode_losses.push(loss);
230 }
231
232 state = next_state;
233 }
234
235 let profile = site_memory.get_profile(&domain);
237 let extraction_result = crate::site_profile::ExtractionResult {
238 text: step_info.text.clone(),
239 xpath: step_info.xpath.clone(),
240 quality_score: step_info.quality_score,
241 parameters: step_info.parameters.clone(),
242 title: None,
243 date: None,
244 };
245 profile.add_extraction(extraction_result);
246
247 epsilon *= config.epsilon_decay;
249 epsilon = epsilon.max(config.epsilon_end);
250
251 if episode % config.target_update_freq == 0 {
253 agent.update_target_network();
254 }
255
256 metrics.episode_rewards.push(episode_reward);
258 metrics.episode_qualities.push(step_info.quality_score);
259
260 if episode % 10 == 0 {
262 let avg_reward = if metrics.episode_rewards.len() >= 100 {
263 metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
264 .iter()
265 .sum::<f32>() / 100.0
266 } else {
267 episode_reward
268 };
269
270 pb.set_message(format!(
271 "Reward: {:.3}, Quality: {:.3}",
272 avg_reward, step_info.quality_score
273 ));
274 }
275 pb.inc(1);
276
277 if episode % config.checkpoint_freq == 0 && episode > 0 && config.num_episodes >= 5000 {
279 let checkpoint_path = config.models_dir.join(format!(
280 "checkpoint_{}_{}_ep{}.onnx",
281 config.algorithm.to_string().to_lowercase(),
282 chrono::Utc::now().format("%Y%m%d_%H%M%S"),
283 episode
284 ));
285
286 match agent.save(&checkpoint_path) {
288 Ok(_) => {
289 let avg_reward = if metrics.episode_rewards.len() >= 100 {
290 metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
291 .iter()
292 .sum::<f32>() / 100.0
293 } else {
294 0.0
295 };
296
297 let avg_quality = if metrics.episode_qualities.len() >= 100 {
298 metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
299 .iter()
300 .sum::<f32>() / 100.0
301 } else {
302 0.0
303 };
304
305 let checkpoint = Checkpoint::new(
306 episode,
307 agent.get_step_count(),
308 avg_reward,
309 avg_quality,
310 metrics.best_avg_quality,
311 epsilon as f32,
312 checkpoint_path.clone(),
313 );
314
315 match checkpoint_manager.save_checkpoint(&checkpoint) {
316 Ok(_) => {
317 if checkpoint_path.exists() {
318 let metadata = std::fs::metadata(&checkpoint_path)?;
319 if metadata.len() > 0 {
320 site_memory.save_all()?;
321 info!("Checkpoint saved at episode {} ({} bytes)", episode, metadata.len());
322 } else {
323 warn!("Checkpoint file is empty, may be corrupted");
324 }
325 } else {
326 warn!("Checkpoint file disappeared after save");
327 }
328 }
329 Err(e) => {
330 warn!("Failed to save checkpoint metadata: {}", e);
331 }
332 }
333 }
334 Err(e) => {
335 warn!("Failed to save model checkpoint: {}", e);
336 }
337 }
338 }
339 }
340
341 pb.finish_with_message("Training completed");
342
343 let final_path = config.models_dir.join(format!(
345 "final_model_{}.onnx",
346 config.algorithm.to_string().to_lowercase()
347 ));
348
349 let mut hyperparams = std::collections::HashMap::new();
350 hyperparams.insert("learning_rate".to_string(), config.learning_rate);
351 hyperparams.insert("batch_size".to_string(), config.batch_size as f64);
352 hyperparams.insert("gamma".to_string(), config.gamma);
353 hyperparams.insert("epsilon_decay".to_string(), config.epsilon_decay);
354 hyperparams.insert("target_update_freq".to_string(), config.target_update_freq as f64);
355 agent.save_with_metadata(&final_path, config.num_episodes, hyperparams)?;
356
357 if final_path.exists() {
359 let metadata = std::fs::metadata(&final_path)?;
360 info!("Final model saved: {} bytes", metadata.len());
361 }
362
363 if let Ok(model_meta) = crate::models::ModelMetadata::load_metadata(&final_path) {
365 model_meta.display();
366 }
367
368 site_memory.save_all()?;
369
370 let final_checkpoint = Checkpoint::new(
372 config.num_episodes,
373 agent.get_step_count(),
374 metrics.episode_rewards.last().copied().unwrap_or(0.0),
375 metrics.episode_qualities.last().copied().unwrap_or(0.0),
376 metrics.best_avg_quality,
377 epsilon as f32,
378 final_path,
379 );
380 checkpoint_manager.save_checkpoint(&final_checkpoint)?;
381
382 info!("Training completed. Best avg quality: {:.3}", metrics.best_avg_quality);
383
384 Ok((agent, metrics))
385}
386
387
/// Run the optimized training loop.
///
/// Differences from `train_standard` visible in this body: samples are
/// filtered per episode by a `CurriculumManager`; rewards come from an
/// `ImprovedRewardCalculator` compared against the baseline extractor's
/// score instead of the environment's raw reward; training steps fire every
/// `config.train_freq` environment steps (once `min_replay_size` is met) plus
/// `config.num_train_steps_per_episode` extra gradient updates per episode;
/// epsilon follows a geometric interpolation schedule rather than
/// multiplicative decay; and checkpoints persist step counters in a sidecar
/// `*.steps.json` file so resume restores them.
///
/// `html_samples` is a list of `(html, file_path)` pairs — the second element
/// is treated as a URL by `extract_domain_from_url`, so for plain file paths
/// the "domain" may just be the first path segment.
///
/// # Errors
/// Propagates environment, agent, extraction, checkpoint, and filesystem
/// errors.
pub fn train_with_improvements(
    config: &Config,
    html_samples: Vec<(String, String)>,
) -> Result<(Box<dyn RLAgent>, TrainingMetrics)> {
    info!("Starting OPTIMIZED training for {} episodes", config.num_episodes);
    info!("Performance settings:");
    info!("  - Batch size: {}", config.batch_size);
    info!("  - Train frequency: every {} steps", config.train_freq);
    info!("  - Gradient updates per episode: {}", config.num_train_steps_per_episode);
    info!("  - Min replay size: {}", config.min_replay_size);
    info!("  - Metrics window: {}", config.metrics_window);
    info!("  - Dataset size: {}", html_samples.len());

    // Prefer CPU when explicitly requested for tuning; otherwise CUDA when
    // available, with CPU fallback.
    let device = if config.use_cpu_for_tuning {
        Device::Cpu
    } else if crate::cuda_is_available() {
        Device::cuda_if_available(0).unwrap_or(Device::Cpu)
    } else {
        Device::Cpu
    };

    // global_step counts environment steps; total_training_steps counts
    // completed gradient updates. Both survive checkpoint resume (below).
    let mut global_step: usize = 0;
    let mut total_training_steps: usize = 0;

    let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
    let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
    let mut replay_buffer = PrioritizedReplayBuffer::new(
        config.replay_buffer_size,
        config.priority_alpha,
        config.priority_beta,
    );

    let mut agent = AgentFactory::create(
        config.algorithm,
        config.state_dim,
        config.num_discrete_actions,
        config.num_continuous_params,
        config.gamma as f32,
        config.learning_rate,
        &device,
    )?;

    // The extractor is cloned because it is also used directly below for the
    // per-episode baseline score.
    let mut env = ArticleExtractionEnvironment::new(baseline_extractor.clone(), config.clone());
    // NOTE(review): equivalent to TrainingMetrics::default(); kept as an
    // explicit literal.
    let mut metrics = TrainingMetrics { episode_rewards: vec![], episode_qualities: vec![], episode_losses: vec![], best_avg_quality: 0.0 };

    let reward_calculator = ImprovedRewardCalculator::new(config.stopwords.clone());
    let mut curriculum = CurriculumManager::new();
    let mut epsilon = config.epsilon_start;

    let checkpoint_dir = config.models_dir.join("checkpoints");
    let checkpoint_manager = CheckpointManager::new(checkpoint_dir, 5)?;

    // Resume only when the latest checkpoint is compatible with this run:
    // its episode fits the current budget, the model file exists, and the
    // model loads cleanly. Step counters are restored from the optional
    // `*.steps.json` sidecar written next to the checkpoint model.
    let start_episode = match checkpoint_manager.load_latest() {
        Ok(Some(checkpoint)) => {
            if checkpoint.episode >= config.num_episodes {
                warn!(
                    "Found checkpoint at episode {} but current run is only {} episodes. Starting fresh.",
                    checkpoint.episode, config.num_episodes
                );
                0
            } else if !checkpoint.model_path.exists() {
                warn!(
                    "Checkpoint references missing model file: {}. Starting fresh.",
                    checkpoint.model_path.display()
                );
                0
            } else {
                info!("Found checkpoint at episode {}, attempting to load...", checkpoint.episode);

                match AgentFactory::load(
                    &checkpoint.model_path,
                    config.state_dim,
                    config.num_discrete_actions,
                    config.num_continuous_params,
                    &device,
                ) {
                    Ok(loaded_agent) => {
                        agent = loaded_agent;
                        epsilon = checkpoint.epsilon as f64;
                        metrics.best_avg_quality = checkpoint.best_quality;

                        // Best-effort restore of step counters; missing or
                        // malformed sidecars are silently ignored.
                        let step_counts_path = checkpoint.model_path.with_extension("steps.json");
                        if step_counts_path.exists() {
                            if let Ok(step_data) = std::fs::read_to_string(&step_counts_path) {
                                if let Ok(step_counts) = serde_json::from_str::<(usize, usize)>(&step_data) {
                                    global_step = step_counts.0;
                                    total_training_steps = step_counts.1;
                                    info!("Resumed step counts: global_step={}, total_training_steps={}",
                                          global_step, total_training_steps);
                                }
                            }
                        }

                        info!("Successfully resumed from checkpoint at episode {}", checkpoint.episode);
                        checkpoint.episode
                    }
                    Err(e) => {
                        warn!("Failed to load checkpoint model: {}. Starting fresh.", e);
                        warn!("Consider deleting checkpoint directory if corruption persists.");
                        0
                    }
                }
            }
        }
        Ok(None) => {
            info!("No checkpoint found, starting fresh training");
            0
        }
        Err(e) => {
            warn!("Error loading checkpoint: {}. Starting fresh.", e);
            0
        }
    };

    let pb = ProgressBar::new((config.num_episodes - start_episode) as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
            .unwrap()
            .progress_chars("█▓▒░"),
    );

    for episode in start_episode..config.num_episodes {
        let mut _episode_training_steps: usize = 0;
        curriculum.update_threshold(episode);

        // Keep only samples the curriculum deems appropriate at this stage.
        let appropriate_samples: Vec<_> = html_samples.iter()
            .filter(|(html, _)| curriculum.is_appropriate(html))
            .collect();

        // No eligible samples (also covers an empty input dataset): abort
        // the loop rather than divide by zero below.
        if appropriate_samples.is_empty() {
            warn!("No appropriate HTML samples for current curriculum");
            break;
        }

        let idx = episode % appropriate_samples.len();
        let (html, file_path) = appropriate_samples[idx];

        let domain = extract_domain_from_url(file_path);

        // Extra visibility for the first few episodes only.
        if episode < 10 {
            info!("Episode {}: File: {}, Domain: {}", episode, file_path, domain);
        }

        let site_profile = site_memory.get_profile(&domain);

        // Baseline score for this document; used as the reference point for
        // the improved reward calculation in the step loop.
        let baseline_result = baseline_extractor.extract(html)?;
        let baseline_score = baseline_result.quality_score;

        let mut state = env.reset(html, file_path.clone(), Some(site_profile))?;

        let mut episode_reward = 0.0;
        let mut done = false;
        let mut step_info = StepInfo {
            quality_score: 0.0,
            text: String::new(),
            xpath: String::new(),
            parameters: std::collections::HashMap::new(),
            step_count: 0,
        };

        while !done {
            let action = agent.select_action(&state, epsilon as f32)?;
            // The environment's own reward is discarded (the `_` binding);
            // the improved calculator supplies the reward instead.
            let (next_state, _, is_done, info) = env.step(action.clone())?;

            let reward = reward_calculator.calculate_reward(&info.text, baseline_score);

            episode_reward += reward;
            done = is_done;
            step_info = info;
            global_step += 1;

            let experience = crate::replay_buffer::Experience {
                state: state.clone(),
                action,
                reward,
                next_state: next_state.clone(),
                done,
            };
            replay_buffer.add(experience);

            // Train every `train_freq` environment steps once the buffer is
            // warm; NaN/inf losses and failed steps are logged and skipped
            // so a single bad update does not kill the run.
            if replay_buffer.len() >= config.min_replay_size &&
               global_step.is_multiple_of(config.train_freq) {
                match agent.train_step(&mut replay_buffer, config.batch_size) {
                    Ok(loss) => {
                        if loss.is_nan() || loss.is_infinite() {
                            warn!("Invalid loss detected at episode {}, step {}: {}", episode, global_step, loss);
                            warn!("Skipping this training step");
                        } else {
                            metrics.episode_losses.push(loss);
                            _episode_training_steps += 1;
                        }
                    }
                    Err(e) => {
                        warn!("Training step failed at episode {}, step {}: {}", episode, global_step, e);
                        warn!("Continuing training...");
                    }
                }
            }

            state = next_state;
        }

        // Additional end-of-episode gradient updates; bail out of the burst
        // on the first invalid loss or training error.
        if replay_buffer.len() >= config.min_replay_size {
            for update_idx in 0..config.num_train_steps_per_episode {
                match agent.train_step(&mut replay_buffer, config.batch_size) {
                    Ok(loss) => {
                        if loss.is_nan() || loss.is_infinite() {
                            warn!("Invalid loss in gradient update {} at episode {}", update_idx, episode);
                            break; }
                        metrics.episode_losses.push(loss);
                        total_training_steps += 1;
                    }
                    Err(e) => {
                        warn!("Gradient update {} failed at episode {}: {}", update_idx, episode, e);
                        break; }
                }
            }
        }

        // Record the final extraction of this episode in the site profile.
        let profile = site_memory.get_profile(&domain);
        let extraction_result = crate::site_profile::ExtractionResult {
            text: step_info.text.clone(),
            xpath: step_info.xpath.clone(),
            quality_score: step_info.quality_score,
            parameters: step_info.parameters.clone(),
            title: None,
            date: None,
        };
        profile.add_extraction(extraction_result);
        // Persist profiles every 100 episodes (logged every 500).
        if episode % 100 == 0 && episode > 0 {
            match site_memory.save_all() {
                Ok(_) => {
                    if episode % 500 == 0 {
                        info!("Site profiles saved at episode {}", episode);
                    }
                }
                Err(e) => {
                    warn!("Failed to save site profiles: {}", e);
                }
            }
        }

        // Geometric interpolation from epsilon_start to epsilon_end.
        // NOTE(review): the schedule anneals over a hard-coded 2000-episode
        // horizon regardless of config.num_episodes — confirm intentional.
        let progress = (episode as f32 / 2000.0).min(1.0);
        epsilon = config.epsilon_start * (config.epsilon_end / config.epsilon_start).powf(progress as f64);
        epsilon = epsilon.max(config.epsilon_end);

        if episode % config.target_update_freq == 0 {
            agent.update_target_network();
        }

        metrics.episode_rewards.push(episode_reward);
        metrics.episode_qualities.push(step_info.quality_score);

        // Progress-bar status: rolling averages over `metrics_window`
        // episodes (or over everything collected so far, early on).
        if episode % config.log_freq == 0 {
            let window = config.metrics_window;
            let avg_reward = if metrics.episode_rewards.len() >= window {
                metrics.episode_rewards[metrics.episode_rewards.len() - window..]
                    .iter()
                    .sum::<f32>() / window as f32
            } else if !metrics.episode_rewards.is_empty() {
                metrics.episode_rewards.iter().sum::<f32>() / metrics.episode_rewards.len() as f32
            } else {
                0.0
            };

            let avg_quality = if metrics.episode_qualities.len() >= window {
                metrics.episode_qualities[metrics.episode_qualities.len() - window..]
                    .iter()
                    .sum::<f32>() / window as f32
            } else if !metrics.episode_qualities.is_empty() {
                metrics.episode_qualities.iter().sum::<f32>() / metrics.episode_qualities.len() as f32
            } else {
                0.0
            };

            let curriculum_threshold = curriculum.get_threshold();
            pb.set_message(format!(
                "R:{:.2} Q:{:.3} ε:{:.3} C:{:.2} Steps:{}",
                avg_reward, avg_quality, epsilon, curriculum_threshold, total_training_steps
            ));
        }
        pb.inc(1);

        // Periodic checkpoint: model file, step-counter sidecar, and
        // checkpoint metadata (rolling 100-episode averages).
        if episode % config.checkpoint_freq == 0 && episode > 0 {
            let checkpoint_path = config.models_dir.join(format!(
                "checkpoint_{}_{}_ep{}.onnx",
                config.algorithm.to_string().to_lowercase(),
                chrono::Utc::now().format("%Y%m%d_%H%M%S"),
                episode
            ));

            match agent.save(&checkpoint_path) {
                Ok(_) => {
                    // Best-effort write of the step counters; a failure here
                    // is deliberately ignored (resume falls back to zero).
                    let step_counts_path = checkpoint_path.with_extension("steps.json");
                    let step_counts = (global_step, total_training_steps);
                    if let Ok(step_data) = serde_json::to_string(&step_counts) {
                        let _ = std::fs::write(&step_counts_path, step_data);
                    }

                    let avg_reward = if metrics.episode_rewards.len() >= 100 {
                        metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
                            .iter()
                            .sum::<f32>() / 100.0
                    } else {
                        0.0
                    };

                    let avg_quality = if metrics.episode_qualities.len() >= 100 {
                        metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
                            .iter()
                            .sum::<f32>() / 100.0
                    } else {
                        0.0
                    };

                    let checkpoint = Checkpoint::new(
                        episode,
                        total_training_steps,
                        avg_reward,
                        avg_quality,
                        metrics.best_avg_quality,
                        epsilon as f32,
                        checkpoint_path.clone(),
                    );

                    match checkpoint_manager.save_checkpoint(&checkpoint) {
                        Ok(_) => {
                            site_memory.save_all()?;
                            info!("Improved training checkpoint saved at episode {} (global_step={})",
                                  episode, global_step);
                        }
                        Err(e) => {
                            warn!("Failed to save checkpoint metadata: {}", e);
                        }
                    }

                    // Heuristic sanity check on the written model's size.
                    if let Ok(metadata) = std::fs::metadata(&checkpoint_path) {
                        let file_size = metadata.len();
                        if file_size < 10_000 {
                            warn!("Checkpoint file suspiciously small: {} bytes", file_size);
                        } else {
                            info!("Checkpoint saved at episode {} ({} bytes)", episode, file_size);
                        }
                    }
                }
                Err(e) => {
                    warn!("Failed to save model checkpoint: {}", e);
                }
            }
        }

        // Track the best 100-episode rolling average quality and snapshot
        // the model whenever it improves.
        if metrics.episode_qualities.len() >= 100 {
            let avg_quality = metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
                .iter()
                .sum::<f32>() / 100.0;

            if avg_quality > metrics.best_avg_quality {
                metrics.best_avg_quality = avg_quality;
                let best_path = config.models_dir.join(format!(
                    "best_model_{}.onnx",
                    config.algorithm.to_string().to_lowercase()
                ));

                match agent.save(&best_path) {
                    Ok(_) => {
                        if let Ok(metadata) = std::fs::metadata(&best_path) {
                            info!("New best {} model saved with quality: {:.3} ({} bytes)",
                                  config.algorithm, avg_quality, metadata.len());
                        } else {
                            info!("New best {} model saved with quality: {:.3}",
                                  config.algorithm, avg_quality);
                        }
                    }
                    Err(e) => {
                        warn!("Failed to save best model: {}", e);
                    }
                }
            }
        }
    }

    pb.finish_with_message("Improved training completed");

    let final_path = config.models_dir.join(format!(
        "final_model_{}.onnx",
        config.algorithm.to_string().to_lowercase()
    ));

    // Persist the run's hyperparameters alongside the final model.
    // NOTE(review): "epsilon_decay" is recorded here although this loop uses
    // the geometric schedule above, not multiplicative decay.
    let mut hyperparams = std::collections::HashMap::new();
    hyperparams.insert("learning_rate".to_string(), config.learning_rate);
    hyperparams.insert("batch_size".to_string(), config.batch_size as f64);
    hyperparams.insert("gamma".to_string(), config.gamma);
    hyperparams.insert("epsilon_decay".to_string(), config.epsilon_decay);
    hyperparams.insert("target_update_freq".to_string(), config.target_update_freq as f64);
    agent.save_with_metadata(&final_path, config.num_episodes, hyperparams)?;

    if final_path.exists() {
        let metadata = std::fs::metadata(&final_path)?;
        info!("Final model saved: {} bytes", metadata.len());
    }

    if let Ok(model_meta) = crate::models::ModelMetadata::load_metadata(&final_path) {
        model_meta.display();
    }

    site_memory.save_all()?;

    let final_checkpoint = Checkpoint::new(
        config.num_episodes,
        total_training_steps,
        metrics.episode_rewards.last().copied().unwrap_or(0.0),
        metrics.episode_qualities.last().copied().unwrap_or(0.0),
        metrics.best_avg_quality,
        epsilon as f32,
        final_path,
    );
    checkpoint_manager.save_checkpoint(&final_checkpoint)?;

    info!("Training completed:");
    info!("  - Total episodes: {}", config.num_episodes);
    info!("  - Total training steps: {}", total_training_steps);
    info!("  - Best avg quality: {:.3}", metrics.best_avg_quality);
    info!("  - Final epsilon: {:.3}", epsilon);

    Ok((agent, metrics))
}
857
858pub fn save_training_plot(metrics: &TrainingMetrics, output_path: &Path) -> Result<()> {
860 use plotters::prelude::*;
861
862 let root = BitMapBackend::new(output_path, (1200, 800))
863 .into_drawing_area();
864 root.fill(&WHITE)
865 .map_err(|e| ExtractionError::ModelError(format!("Plot error: {}", e)))?;
866
867 let areas = root.split_evenly((2, 1));
868 let upper = &areas[0];
869 let lower = &areas[1];
870
871 let max_episodes = metrics.episode_rewards.len();
873 let max_reward = metrics.episode_rewards.iter()
874 .copied()
875 .fold(f32::NEG_INFINITY, f32::max);
876 let min_reward = metrics.episode_rewards.iter()
877 .copied()
878 .fold(f32::INFINITY, f32::min);
879
880 let mut chart = ChartBuilder::on(upper)
881 .caption("Episode Rewards", ("sans-serif", 30))
882 .margin(10)
883 .x_label_area_size(30)
884 .y_label_area_size(40)
885 .build_cartesian_2d(0..max_episodes, min_reward..max_reward)
886 .map_err(|e| ExtractionError::ModelError(format!("Chart error: {}", e)))?;
887
888 chart.configure_mesh()
889 .draw()
890 .map_err(|e| ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
891
892 chart.draw_series(LineSeries::new(
893 metrics.episode_rewards.iter().enumerate().map(|(i, &r)| (i, r)),
894 &BLUE,
895 ))
896 .map_err(|e| ExtractionError::ModelError(format!("Series error: {}", e)))?;
897
898 let max_quality = metrics.episode_qualities.iter()
900 .copied()
901 .fold(f32::NEG_INFINITY, f32::max);
902
903 let mut chart2 = ChartBuilder::on(lower)
904 .caption("Episode Quality", ("sans-serif", 30))
905 .margin(10)
906 .x_label_area_size(30)
907 .y_label_area_size(40)
908 .build_cartesian_2d(0..max_episodes, 0.0..max_quality)
909 .map_err(|e| ExtractionError::ModelError(format!("Chart error: {}", e)))?;
910
911 chart2.configure_mesh()
912 .draw()
913 .map_err(|e| ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
914
915 chart2.draw_series(LineSeries::new(
916 metrics.episode_qualities.iter().enumerate().map(|(i, &q)| (i, q)),
917 &GREEN,
918 ))
919 .map_err(|e| ExtractionError::ModelError(format!("Series error: {}", e)))?;
920
921 root.present().map_err(|e| crate::ExtractionError::IoError(
922 std::io::Error::other(e.to_string())
923 ))?;
924
925 info!("Training plot saved to: {}", output_path.display());
926
927 Ok(())
928}