1use crate::{
6 Config, ArticleExtractionEnvironment, BaselineExtractor, ExtractionError,
7 agents::{AgentFactory, RLAgent},
8};
9
10use crate::{
11 replay_buffer::PrioritizedReplayBuffer,
12 SiteProfileMemory,
13 curriculum::CurriculumManager,
14 Result,
15};
16
17use crate::environment::StepInfo;
18use rand::RngExt;
19use indicatif::{ProgressBar, ProgressStyle};
20use std::path::Path;
21use tracing::{info, warn};
22use crate::{Checkpoint, CheckpointManager};
23use candle_nn::{VarMap};
24use candle_core::Device;
25
26
27fn extract_domain_from_url(url: &str) -> String {
30 use url::Url;
31
32 match Url::parse(url) {
34 Ok(parsed_url) => {
35 parsed_url.host_str()
36 .map(|h| h.to_string())
37 .unwrap_or_else(|| "unknown".to_string())
38 }
39 Err(_) => {
40 let url = url.trim();
42 let without_protocol = url.strip_prefix("https://")
43 .or_else(|| url.strip_prefix("http://"))
44 .unwrap_or(url);
45
46 let host_part = without_protocol.split('/').next().unwrap_or("");
48
49 let domain = host_part.split(':').next().unwrap_or("");
51
52 if domain.is_empty() {
53 "unknown".to_string()
54 } else {
55 domain.to_string()
56 }
57 }
58 }
59}
60
61
/// A single HTML document used as a training example.
///
/// During training, `url` is used to derive the site-profile domain, and
/// `ground_truth_text` (when present) is handed to the environment on reset as the
/// reference article text.
#[derive(Clone, Debug)]
pub struct TrainingSample {
    /// Raw HTML of the page.
    pub html: String,
    /// Source URL (or other identifier) of the page.
    pub url: String,
    /// Optional reference article text for this page.
    pub ground_truth_text: Option<String>,
}

impl TrainingSample {
    /// Creates a sample that carries reference article text.
    pub fn with_ground_truth(html: String, url: String, ground_truth_text: String) -> Self {
        let ground_truth_text = Some(ground_truth_text);
        TrainingSample {
            html,
            url,
            ground_truth_text,
        }
    }
}

/// Builds a sample without ground truth from an `(html, url)` pair.
impl From<(String, String)> for TrainingSample {
    fn from(pair: (String, String)) -> Self {
        let (html, url) = pair;
        TrainingSample {
            html,
            url,
            ground_truth_text: None,
        }
    }
}
87
/// Statistics collected by the training loops, one entry per episode or training step.
#[derive(Clone, Debug, Default)]
pub struct TrainingMetrics {
    /// Total reward accumulated in each episode, in episode order.
    pub episode_rewards: Vec<f32>,
    /// Quality score reported by the last step of each episode.
    pub episode_qualities: Vec<f32>,
    /// Loss values returned by successful training steps.
    pub episode_losses: Vec<f32>,
    /// Best windowed average quality seen so far (or restored from a checkpoint).
    pub best_avg_quality: f32,
}
96
97
98pub fn train_standard(
100 config: &Config,
101 html_samples: Vec<TrainingSample>,
102) -> Result<(Box<dyn RLAgent>, TrainingMetrics)> {
103 info!("Starting standard training for {} episodes", config.num_episodes);
104
105 let device = if config.use_cpu_for_tuning {
106 Device::Cpu
107 } else if crate::cuda_is_available() {
108 Device::cuda_if_available(0).unwrap_or(Device::Cpu)
109 } else {
110 Device::Cpu
111 };
112
113 let _varmap = VarMap::new();
114
115 let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
117 let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
118 let mut replay_buffer = PrioritizedReplayBuffer::new(
119 config.replay_buffer_size,
120 config.priority_alpha,
121 config.priority_beta,
122 );
123
124 let mut agent = AgentFactory::create(
125 config.algorithm,
126 config.state_dim,
127 config.num_discrete_actions,
128 config.num_continuous_params,
129 config.gamma as f32,
130 config.learning_rate,
131 &device,
132 )?;
133
134 let mut env = ArticleExtractionEnvironment::new(baseline_extractor, config.clone());
143 let mut metrics = TrainingMetrics::default();
144 let mut epsilon = config.epsilon_start;
145
146 let checkpoint_dir = config.models_dir.join("checkpoints");
148 let checkpoint_manager = CheckpointManager::new(checkpoint_dir, 5)?;
149
150 let start_episode = match checkpoint_manager.load_latest() {
152 Ok(Some(checkpoint)) => {
153 if checkpoint.episode >= config.num_episodes {
155 warn!(
156 "Found checkpoint at episode {} but current run is only {} episodes. Starting fresh.",
157 checkpoint.episode, config.num_episodes
158 );
159 0
160 } else if !checkpoint.model_path.exists() {
161 warn!(
162 "Checkpoint references missing model file: {}. Starting fresh.",
163 checkpoint.model_path.display()
164 );
165 0
166 } else {
167 info!("Found checkpoint at episode {}, attempting to load...", checkpoint.episode);
169
170 match AgentFactory::load(
171 &checkpoint.model_path,
172 config.state_dim,
173 config.num_discrete_actions,
174 config.num_continuous_params,
175 &device,
176 ) {
177 Ok(loaded_agent) => {
178 agent = loaded_agent;
179 epsilon = checkpoint.epsilon as f64;
180 metrics.best_avg_quality = checkpoint.best_quality;
181 info!("Successfully resumed from checkpoint at episode {}", checkpoint.episode);
182 checkpoint.episode
183 }
184 Err(e) => {
185 warn!("Failed to load checkpoint model: {}. Starting fresh.", e);
186 warn!("Consider deleting checkpoint directory if corruption persists.");
187 0
188 }
189 }
190 }
191 }
192 Ok(None) => {
193 info!("No checkpoint found, starting fresh training");
194 0
195 }
196 Err(e) => {
197 warn!("Error loading checkpoint: {}. Starting fresh.", e);
198 0
199 }
200 };
201
202 let pb = ProgressBar::new((config.num_episodes - start_episode) as u64);
204 pb.set_style(
205 ProgressStyle::default_bar()
206 .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
207 .unwrap()
208 .progress_chars("=>-"),
209 );
210
211 for episode in start_episode..config.num_episodes {
212 let idx = rand::rng().random_range(0..html_samples.len());
215 let sample = &html_samples[idx];
216 let (html, url) = (&sample.html, &sample.url);
217
218 let domain = extract_domain_from_url(url);
219
220 let site_profile = site_memory.get_profile(&domain);
221
222 let mut state = env.reset(
224 html,
225 url.clone(),
226 sample.ground_truth_text.as_deref(),
227 Some(site_profile),
228 )?;
229
230 let mut episode_reward = 0.0;
231 let mut done = false;
232 let mut step_info = StepInfo {
233 quality_score: 0.0,
234 text: String::new(),
235 xpath: String::new(),
236 parameters: std::collections::HashMap::new(),
237 step_count: 0,
238 };
239
240 while !done {
242 let action = agent.select_action(&state, epsilon as f32)?;
243 let (next_state, reward, is_done, info) = env.step(action.clone())?;
244
245 episode_reward += reward;
246 done = is_done;
247 step_info = info;
248
249 let experience = crate::replay_buffer::Experience {
251 state: state.clone(),
252 action,
253 reward,
254 next_state: next_state.clone(),
255 done,
256 };
257 replay_buffer.add(experience);
258
259 if replay_buffer.len() > config.batch_size * 10 {
261 let loss = agent.train_step(&mut replay_buffer, config.batch_size)?;
262 metrics.episode_losses.push(loss);
263 }
264
265 state = next_state;
266 }
267
268 let profile = site_memory.get_profile(&domain);
270 let extraction_result = crate::site_profile::ExtractionResult {
271 text: step_info.text.clone(),
272 xpath: step_info.xpath.clone(),
273 quality_score: step_info.quality_score,
274 parameters: step_info.parameters.clone(),
275 title: None,
276 date: None,
277 };
278 profile.add_extraction(extraction_result);
279
280 epsilon *= config.epsilon_decay;
282 epsilon = epsilon.max(config.epsilon_end);
283
284 if episode % config.target_update_freq == 0 {
286 agent.update_target_network();
287 }
288
289 metrics.episode_rewards.push(episode_reward);
291 metrics.episode_qualities.push(step_info.quality_score);
292
293 if episode % 10 == 0 {
295 let avg_reward = if metrics.episode_rewards.len() >= 100 {
296 metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
297 .iter()
298 .sum::<f32>() / 100.0
299 } else {
300 episode_reward
301 };
302
303 pb.set_message(format!(
304 "Reward: {:.3}, Quality: {:.3}",
305 avg_reward, step_info.quality_score
306 ));
307 }
308 pb.inc(1);
309
310 if episode % config.checkpoint_freq == 0 && episode > 0 && config.num_episodes >= 5000 {
312 let checkpoint_path = config.models_dir.join(format!(
313 "checkpoint_{}_{}_ep{}.onnx",
314 config.algorithm.to_string().to_lowercase(),
315 chrono::Utc::now().format("%Y%m%d_%H%M%S"),
316 episode
317 ));
318
319 match agent.save(&checkpoint_path) {
321 Ok(_) => {
322 let avg_reward = if metrics.episode_rewards.len() >= 100 {
323 metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
324 .iter()
325 .sum::<f32>() / 100.0
326 } else {
327 0.0
328 };
329
330 let avg_quality = if metrics.episode_qualities.len() >= 100 {
331 metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
332 .iter()
333 .sum::<f32>() / 100.0
334 } else {
335 0.0
336 };
337
338 let checkpoint = Checkpoint::new(
339 episode,
340 agent.get_step_count(),
341 avg_reward,
342 avg_quality,
343 metrics.best_avg_quality,
344 epsilon as f32,
345 checkpoint_path.clone(),
346 );
347
348 match checkpoint_manager.save_checkpoint(&checkpoint) {
349 Ok(_) => {
350 if checkpoint_path.exists() {
351 let metadata = std::fs::metadata(&checkpoint_path)?;
352 if metadata.len() > 0 {
353 site_memory.save_all()?;
354 info!("Checkpoint saved at episode {} ({} bytes)", episode, metadata.len());
355 } else {
356 warn!("Checkpoint file is empty, may be corrupted");
357 }
358 } else {
359 warn!("Checkpoint file disappeared after save");
360 }
361 }
362 Err(e) => {
363 warn!("Failed to save checkpoint metadata: {}", e);
364 }
365 }
366 }
367 Err(e) => {
368 warn!("Failed to save model checkpoint: {}", e);
369 }
370 }
371 }
372 }
373
374 pb.finish_with_message("Training completed");
375
376 let final_path = config.models_dir.join(format!(
378 "final_model_{}.onnx",
379 config.algorithm.to_string().to_lowercase()
380 ));
381
382 let mut hyperparams = std::collections::HashMap::new();
383 hyperparams.insert("learning_rate".to_string(), config.learning_rate);
384 hyperparams.insert("batch_size".to_string(), config.batch_size as f64);
385 hyperparams.insert("gamma".to_string(), config.gamma);
386 hyperparams.insert("epsilon_decay".to_string(), config.epsilon_decay);
387 hyperparams.insert("target_update_freq".to_string(), config.target_update_freq as f64);
388 agent.save_with_metadata(&final_path, config.num_episodes, hyperparams)?;
389
390 if final_path.exists() {
392 let metadata = std::fs::metadata(&final_path)?;
393 info!("Final model saved: {} bytes", metadata.len());
394 }
395
396 if let Ok(model_meta) = crate::models::ModelMetadata::load_metadata(&final_path) {
398 model_meta.display();
399 }
400
401 site_memory.save_all()?;
402
403 let final_checkpoint = Checkpoint::new(
405 config.num_episodes,
406 agent.get_step_count(),
407 metrics.episode_rewards.last().copied().unwrap_or(0.0),
408 metrics.episode_qualities.last().copied().unwrap_or(0.0),
409 metrics.best_avg_quality,
410 epsilon as f32,
411 final_path,
412 );
413 checkpoint_manager.save_checkpoint(&final_checkpoint)?;
414
415 info!("Training completed. Best avg quality: {:.3}", metrics.best_avg_quality);
416
417 Ok((agent, metrics))
418}
419
420
421pub fn train_with_improvements(
423 config: &Config,
424 html_samples: Vec<TrainingSample>,
425) -> Result<(Box<dyn RLAgent>, TrainingMetrics)> {
426 info!("Starting OPTIMIZED training for {} episodes", config.num_episodes);
427 info!("Performance settings:");
428 info!(" - Batch size: {}", config.batch_size);
429 info!(" - Train frequency: every {} steps", config.train_freq);
430 info!(" - Gradient updates per episode: {}", config.num_train_steps_per_episode);
431 info!(" - Min replay size: {}", config.min_replay_size);
432 info!(" - Metrics window: {}", config.metrics_window);
433 info!(" - Dataset size: {}", html_samples.len());
434
435 let device = if config.use_cpu_for_tuning {
437 Device::Cpu } else if crate::cuda_is_available() {
439 Device::cuda_if_available(0).unwrap_or(Device::Cpu)
440 } else {
441 Device::Cpu
442 };
443
444 let mut global_step:usize = 0;
446 let mut total_training_steps:usize = 0;
447
448 let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
450 let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
451 let mut replay_buffer = PrioritizedReplayBuffer::new(
452 config.replay_buffer_size,
453 config.priority_alpha,
454 config.priority_beta,
455 );
456
457 let mut agent = AgentFactory::create(
459 config.algorithm,
460 config.state_dim,
461 config.num_discrete_actions,
462 config.num_continuous_params,
463 config.gamma as f32,
464 config.learning_rate,
465 &device,
466 )?;
467
468 let mut env = ArticleExtractionEnvironment::new(baseline_extractor.clone(), config.clone());
469 let mut metrics = TrainingMetrics { episode_rewards: vec![], episode_qualities: vec![], episode_losses: vec![], best_avg_quality: 0.0 };
470
471 let mut curriculum = CurriculumManager::new();
473 let mut epsilon = config.epsilon_start;
474
475 let checkpoint_dir = config.models_dir.join("checkpoints");
477 let checkpoint_manager = CheckpointManager::new(checkpoint_dir, 5)?;
478
479 let start_episode = match checkpoint_manager.load_latest() {
481 Ok(Some(checkpoint)) => {
482 if checkpoint.episode >= config.num_episodes {
484 warn!(
485 "Found checkpoint at episode {} but current run is only {} episodes. Starting fresh.",
486 checkpoint.episode, config.num_episodes
487 );
488 0
489 } else if !checkpoint.model_path.exists() {
490 warn!(
491 "Checkpoint references missing model file: {}. Starting fresh.",
492 checkpoint.model_path.display()
493 );
494 0
495 } else {
496 info!("Found checkpoint at episode {}, attempting to load...", checkpoint.episode);
498
499 match AgentFactory::load(
500 &checkpoint.model_path,
501 config.state_dim,
502 config.num_discrete_actions,
503 config.num_continuous_params,
504 &device,
505 ) {
506 Ok(loaded_agent) => {
507 agent = loaded_agent;
508 epsilon = checkpoint.epsilon as f64;
509 metrics.best_avg_quality = checkpoint.best_quality;
510
511 let step_counts_path = checkpoint.model_path.with_extension("steps.json");
513 if step_counts_path.exists() {
514 if let Ok(step_data) = std::fs::read_to_string(&step_counts_path) {
515 if let Ok(step_counts) = serde_json::from_str::<(usize, usize)>(&step_data) {
516 global_step = step_counts.0;
517 total_training_steps = step_counts.1;
518 info!("Resumed step counts: global_step={}, total_training_steps={}",
519 global_step, total_training_steps);
520 }
521 }
522 }
523
524 info!("Successfully resumed from checkpoint at episode {}", checkpoint.episode);
525 checkpoint.episode
526 }
527 Err(e) => {
528 warn!("Failed to load checkpoint model: {}. Starting fresh.", e);
529 warn!("Consider deleting checkpoint directory if corruption persists.");
530 0
531 }
532 }
533 }
534 }
535 Ok(None) => {
536 info!("No checkpoint found, starting fresh training");
537 0
538 }
539 Err(e) => {
540 warn!("Error loading checkpoint: {}. Starting fresh.", e);
541 0
542 }
543 };
544
545 let pb = ProgressBar::new((config.num_episodes - start_episode) as u64);
547 pb.set_style(
548 ProgressStyle::default_bar()
549 .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
550 .unwrap()
551 .progress_chars("█▓▒░"),
552 );
553
554 for episode in start_episode..config.num_episodes {
555 let mut _episode_training_steps:usize = 0;
556 curriculum.update_threshold(episode);
558
559 let appropriate_samples: Vec<_> = html_samples.iter()
561 .filter(|s| curriculum.is_appropriate(&s.html))
562 .collect();
563
564 if appropriate_samples.is_empty() {
565 warn!("No appropriate HTML samples for current curriculum");
566 break;
567 }
568
569 let idx = rand::rng().random_range(0..appropriate_samples.len());
570 let sample = appropriate_samples[idx];
571 let html = &sample.html;
572 let file_path = &sample.url;
573
574 let domain = extract_domain_from_url(file_path);
576
577 if episode < 10 {
579 info!("Episode {}: File: {}, Domain: {}", episode, file_path, domain);
580 }
581
582 let site_profile = site_memory.get_profile(&domain);
583
584 let mut state = env.reset(
587 html,
588 file_path.clone(),
589 sample.ground_truth_text.as_deref(),
590 Some(site_profile),
591 )?;
592
593 let mut episode_reward = 0.0;
594 let mut done = false;
595 let mut step_info = StepInfo {
596 quality_score: 0.0,
597 text: String::new(),
598 xpath: String::new(),
599 parameters: std::collections::HashMap::new(),
600 step_count: 0,
601 };
602
603 while !done {
605 let action = agent.select_action(&state, epsilon as f32)?;
606 let (next_state, reward, is_done, info) = env.step(action.clone())?;
610
611 episode_reward += reward;
612 done = is_done;
613 step_info = info;
614 global_step += 1;
616
617 let experience = crate::replay_buffer::Experience {
619 state: state.clone(),
620 action,
621 reward,
622 next_state: next_state.clone(),
623 done,
624 };
625 replay_buffer.add(experience);
626
627 if replay_buffer.len() >= config.min_replay_size &&
629 global_step.is_multiple_of(config.train_freq) {
630 match agent.train_step(&mut replay_buffer, config.batch_size) {
632 Ok(loss) => {
633 if loss.is_nan() || loss.is_infinite() {
635 warn!("Invalid loss detected at episode {}, step {}: {}", episode, global_step, loss);
636 warn!("Skipping this training step");
637 } else {
638 metrics.episode_losses.push(loss);
639 _episode_training_steps += 1;
640 }
641 }
642 Err(e) => {
643 warn!("Training step failed at episode {}, step {}: {}", episode, global_step, e);
644 warn!("Continuing training...");
645 }
647 }
648 }
649
650 state = next_state;
651 }
652
653 if replay_buffer.len() >= config.min_replay_size {
655 for update_idx in 0..config.num_train_steps_per_episode {
656 match agent.train_step(&mut replay_buffer, config.batch_size) {
657 Ok(loss) => {
658 if loss.is_nan() || loss.is_infinite() {
659 warn!("Invalid loss in gradient update {} at episode {}", update_idx, episode);
660 break; }
662 metrics.episode_losses.push(loss);
663 total_training_steps += 1;
664 }
665 Err(e) => {
666 warn!("Gradient update {} failed at episode {}: {}", update_idx, episode, e);
667 break; }
669 }
670 }
671 }
672
673 let profile = site_memory.get_profile(&domain);
675 let extraction_result = crate::site_profile::ExtractionResult {
676 text: step_info.text.clone(),
677 xpath: step_info.xpath.clone(),
678 quality_score: step_info.quality_score,
679 parameters: step_info.parameters.clone(),
680 title: None,
681 date: None,
682 };
683 profile.add_extraction(extraction_result);
684 if episode % 100 == 0 && episode > 0 {
686 match site_memory.save_all() {
687 Ok(_) => {
688 if episode % 500 == 0 {
689 info!("Site profiles saved at episode {}", episode);
690 }
691 }
692 Err(e) => {
693 warn!("Failed to save site profiles: {}", e);
694 }
695 }
696 }
697
698 let progress = (episode as f32 / 2000.0).min(1.0);
700 epsilon = config.epsilon_start * (config.epsilon_end / config.epsilon_start).powf(progress as f64);
701 epsilon = epsilon.max(config.epsilon_end);
702
703 if episode % config.target_update_freq == 0 {
705 agent.update_target_network();
706 }
707
708 metrics.episode_rewards.push(episode_reward);
710 metrics.episode_qualities.push(step_info.quality_score);
711
712 if episode % config.log_freq == 0 {
714 let window = config.metrics_window;
715 let avg_reward = if metrics.episode_rewards.len() >= window {
716 metrics.episode_rewards[metrics.episode_rewards.len() - window..]
717 .iter()
718 .sum::<f32>() / window as f32
719 } else if !metrics.episode_rewards.is_empty() {
720 metrics.episode_rewards.iter().sum::<f32>() / metrics.episode_rewards.len() as f32
721 } else {
722 0.0
723 };
724
725 let avg_quality = if metrics.episode_qualities.len() >= window {
726 metrics.episode_qualities[metrics.episode_qualities.len() - window..]
727 .iter()
728 .sum::<f32>() / window as f32
729 } else if !metrics.episode_qualities.is_empty() {
730 metrics.episode_qualities.iter().sum::<f32>() / metrics.episode_qualities.len() as f32
731 } else {
732 0.0
733 };
734
735 let curriculum_threshold = curriculum.get_threshold();
736 pb.set_message(format!(
737 "R:{:.2} Q:{:.3} ε:{:.3} C:{:.2} Steps:{}",
738 avg_reward, avg_quality, epsilon, curriculum_threshold, total_training_steps
739 ));
740 }
741 pb.inc(1);
742
743 if episode % config.checkpoint_freq == 0 && episode > 0 {
745 let checkpoint_path = config.models_dir.join(format!(
746 "checkpoint_{}_{}_ep{}.onnx",
747 config.algorithm.to_string().to_lowercase(),
748 chrono::Utc::now().format("%Y%m%d_%H%M%S"),
749 episode
750 ));
751
752 match agent.save(&checkpoint_path) {
753 Ok(_) => {
754 let step_counts_path = checkpoint_path.with_extension("steps.json");
756 let step_counts = (global_step, total_training_steps);
757 if let Ok(step_data) = serde_json::to_string(&step_counts) {
758 let _ = std::fs::write(&step_counts_path, step_data);
759 }
760
761 let avg_reward = if metrics.episode_rewards.len() >= 100 {
762 metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
763 .iter()
764 .sum::<f32>() / 100.0
765 } else {
766 0.0
767 };
768
769 let avg_quality = if metrics.episode_qualities.len() >= 100 {
770 metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
771 .iter()
772 .sum::<f32>() / 100.0
773 } else {
774 0.0
775 };
776
777 let checkpoint = Checkpoint::new(
778 episode,
779 total_training_steps,
780 avg_reward,
781 avg_quality,
782 metrics.best_avg_quality,
783 epsilon as f32,
784 checkpoint_path.clone(),
785 );
786
787 match checkpoint_manager.save_checkpoint(&checkpoint) {
788 Ok(_) => {
789 site_memory.save_all()?;
790 info!("Improved training checkpoint saved at episode {} (global_step={})",
791 episode, global_step);
792 }
793 Err(e) => {
794 warn!("Failed to save checkpoint metadata: {}", e);
795 }
796 }
797
798 if let Ok(metadata) = std::fs::metadata(&checkpoint_path) {
799 let file_size = metadata.len();
800 if file_size < 10_000 {
801 warn!("Checkpoint file suspiciously small: {} bytes", file_size);
802 } else {
803 info!("Checkpoint saved at episode {} ({} bytes)", episode, file_size);
804 }
805 }
806 }
807 Err(e) => {
808 warn!("Failed to save model checkpoint: {}", e);
809 }
810 }
811 }
812
813 if metrics.episode_qualities.len() >= 100 {
815 let avg_quality = metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
816 .iter()
817 .sum::<f32>() / 100.0;
818
819 if avg_quality > metrics.best_avg_quality {
820 metrics.best_avg_quality = avg_quality;
821 let best_path = config.models_dir.join(format!(
822 "best_model_{}.onnx",
823 config.algorithm.to_string().to_lowercase()
824 ));
825
826 match agent.save(&best_path) {
827 Ok(_) => {
828 if let Ok(metadata) = std::fs::metadata(&best_path) {
829 info!("New best {} model saved with quality: {:.3} ({} bytes)",
830 config.algorithm, avg_quality, metadata.len());
831 } else {
832 info!("New best {} model saved with quality: {:.3}",
833 config.algorithm, avg_quality);
834 }
835 }
836 Err(e) => {
837 warn!("Failed to save best model: {}", e);
838 }
839 }
840 }
841 }
842 }
843
844 pb.finish_with_message("Improved training completed");
845
846 let final_path = config.models_dir.join(format!(
848 "final_model_{}.onnx",
849 config.algorithm.to_string().to_lowercase()
850 ));
851
852 let mut hyperparams = std::collections::HashMap::new();
853 hyperparams.insert("learning_rate".to_string(), config.learning_rate);
854 hyperparams.insert("batch_size".to_string(), config.batch_size as f64);
855 hyperparams.insert("gamma".to_string(), config.gamma);
856 hyperparams.insert("epsilon_decay".to_string(), config.epsilon_decay);
857 hyperparams.insert("target_update_freq".to_string(), config.target_update_freq as f64);
858 agent.save_with_metadata(&final_path, config.num_episodes, hyperparams)?;
859
860 if final_path.exists() {
862 let metadata = std::fs::metadata(&final_path)?;
863 info!("Final model saved: {} bytes", metadata.len());
864 }
865
866 if let Ok(model_meta) = crate::models::ModelMetadata::load_metadata(&final_path) {
868 model_meta.display();
869 }
870
871 site_memory.save_all()?;
872
873 let final_checkpoint = Checkpoint::new(
875 config.num_episodes,
876 total_training_steps,
877 metrics.episode_rewards.last().copied().unwrap_or(0.0),
878 metrics.episode_qualities.last().copied().unwrap_or(0.0),
879 metrics.best_avg_quality,
880 epsilon as f32,
881 final_path,
882 );
883 checkpoint_manager.save_checkpoint(&final_checkpoint)?;
884
885 info!("Training completed:");
886 info!(" - Total episodes: {}", config.num_episodes);
887 info!(" - Total training steps: {}", total_training_steps);
888 info!(" - Best avg quality: {:.3}", metrics.best_avg_quality);
889 info!(" - Final epsilon: {:.3}", epsilon);
890
891 Ok((agent, metrics))
892}
893
894pub fn save_training_plot(metrics: &TrainingMetrics, output_path: &Path) -> Result<()> {
896 use plotters::prelude::*;
897
898 let root = BitMapBackend::new(output_path, (1200, 800))
899 .into_drawing_area();
900 root.fill(&WHITE)
901 .map_err(|e| ExtractionError::ModelError(format!("Plot error: {}", e)))?;
902
903 let areas = root.split_evenly((2, 1));
904 let upper = &areas[0];
905 let lower = &areas[1];
906
907 let max_episodes = metrics.episode_rewards.len();
909 let max_reward = metrics.episode_rewards.iter()
910 .copied()
911 .fold(f32::NEG_INFINITY, f32::max);
912 let min_reward = metrics.episode_rewards.iter()
913 .copied()
914 .fold(f32::INFINITY, f32::min);
915
916 let mut chart = ChartBuilder::on(upper)
917 .caption("Episode Rewards", ("sans-serif", 30))
918 .margin(10)
919 .x_label_area_size(30)
920 .y_label_area_size(40)
921 .build_cartesian_2d(0..max_episodes, min_reward..max_reward)
922 .map_err(|e| ExtractionError::ModelError(format!("Chart error: {}", e)))?;
923
924 chart.configure_mesh()
925 .draw()
926 .map_err(|e| ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
927
928 chart.draw_series(LineSeries::new(
929 metrics.episode_rewards.iter().enumerate().map(|(i, &r)| (i, r)),
930 &BLUE,
931 ))
932 .map_err(|e| ExtractionError::ModelError(format!("Series error: {}", e)))?;
933
934 let max_quality = metrics.episode_qualities.iter()
936 .copied()
937 .fold(f32::NEG_INFINITY, f32::max);
938
939 let mut chart2 = ChartBuilder::on(lower)
940 .caption("Episode Quality", ("sans-serif", 30))
941 .margin(10)
942 .x_label_area_size(30)
943 .y_label_area_size(40)
944 .build_cartesian_2d(0..max_episodes, 0.0..max_quality)
945 .map_err(|e| ExtractionError::ModelError(format!("Chart error: {}", e)))?;
946
947 chart2.configure_mesh()
948 .draw()
949 .map_err(|e| ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
950
951 chart2.draw_series(LineSeries::new(
952 metrics.episode_qualities.iter().enumerate().map(|(i, &q)| (i, q)),
953 &GREEN,
954 ))
955 .map_err(|e| ExtractionError::ModelError(format!("Series error: {}", e)))?;
956
957 root.present().map_err(|e| crate::ExtractionError::IoError(
958 std::io::Error::other(e.to_string())
959 ))?;
960
961 info!("Training plot saved to: {}", output_path.display());
962
963 Ok(())
964}