// content_extractor_rl/training.rs

// ============================================================================
// FILE: crates/content-extractor-rl/src/training.rs
// ============================================================================
4
5use crate::{
6    Config, ArticleExtractionEnvironment, BaselineExtractor, ExtractionError,
7    agents::{AgentFactory, RLAgent},
8};
9
10use crate::{
11    replay_buffer::PrioritizedReplayBuffer,
12    SiteProfileMemory,
13    reward::ImprovedRewardCalculator,
14    curriculum::CurriculumManager,
15    Result,
16};
17
18use crate::environment::StepInfo;
19use indicatif::{ProgressBar, ProgressStyle};
20use std::path::Path;
21use tracing::{info, warn};
22use crate::{Checkpoint, CheckpointManager};
23use candle_nn::{VarMap};
24use candle_core::Device;
25
26
27/// Extract domain from URL
28/// The second element of html_samples is now the actual URL from JSON
29fn extract_domain_from_url(url: &str) -> String {
30    use url::Url;
31
32    // Parse the URL to extract domain
33    match Url::parse(url) {
34        Ok(parsed_url) => {
35            parsed_url.host_str()
36                .map(|h| h.to_string())
37                .unwrap_or_else(|| "unknown".to_string())
38        }
39        Err(_) => {
40            // If URL parsing fails, try to extract domain directly
41            let url = url.trim();
42            let without_protocol = url.strip_prefix("https://")
43                .or_else(|| url.strip_prefix("http://"))
44                .unwrap_or(url);
45
46            // Split by '/' to get the host part
47            let host_part = without_protocol.split('/').next().unwrap_or("");
48
49            // Split by ':' to remove port (if any)
50            let domain = host_part.split(':').next().unwrap_or("");
51
52            if domain.is_empty() {
53                "unknown".to_string()
54            } else {
55                domain.to_string()
56            }
57        }
58    }
59}
60
61
/// Aggregate metrics collected over a training run.
#[derive(Debug, Clone, Default)]  // Default gives an all-empty/zeroed metrics set
pub struct TrainingMetrics {
    // Total reward accumulated during each episode, in episode order.
    pub episode_rewards: Vec<f32>,
    // Final extraction quality score reached in each episode.
    pub episode_qualities: Vec<f32>,
    // Loss of every gradient update performed (multiple entries per episode
    // are possible, so this is NOT aligned with the per-episode vectors).
    pub episode_losses: Vec<f32>,
    // Best rolling-average quality observed so far; used when writing
    // checkpoints and for best-model tracking.
    pub best_avg_quality: f32,
}
70
71
72/// Standard training loop with checkpoint support
73pub fn train_standard(
74    config: &Config,
75    html_samples: Vec<(String, String)>,
76) -> Result<(Box<dyn RLAgent>, TrainingMetrics)> {
77    info!("Starting standard training for {} episodes", config.num_episodes);
78
79    let device = if config.use_cpu_for_tuning {
80        Device::Cpu
81    } else if crate::cuda_is_available() {
82        Device::cuda_if_available(0).unwrap_or(Device::Cpu)
83    } else {
84        Device::Cpu
85    };
86
87    let _varmap = VarMap::new();
88
89    // Initialize components
90    let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
91    let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
92    let mut replay_buffer = PrioritizedReplayBuffer::new(
93        config.replay_buffer_size,
94        config.priority_alpha,
95        config.priority_beta,
96    );
97
98    let mut agent = AgentFactory::create(
99        config.algorithm,
100        config.state_dim,
101        config.num_discrete_actions,
102        config.num_continuous_params,
103        config.gamma as f32,
104        config.learning_rate,
105        &device,
106    )?;
107
108    // TODO: correctly implement VERIFY initialization
109    // if !agent.online_network.verify_initialization()? {
110    //     return Err(ExtractionError::ModelError(
111    //         "Model initialization failed - weights are all zeros!".to_string()
112    //     ));
113    // }
114    // info!("Model initialized successfully with non-zero weights");
115
116    let mut env = ArticleExtractionEnvironment::new(baseline_extractor, config.clone());
117    let mut metrics = TrainingMetrics::default();
118    let mut epsilon = config.epsilon_start;
119
120    // Initialize checkpoint manager
121    let checkpoint_dir = config.models_dir.join("checkpoints");
122    let checkpoint_manager = CheckpointManager::new(checkpoint_dir, 5)?;
123
124    // CRITICAL FIX: Only resume if checkpoint is valid and from compatible run
125    let start_episode = match checkpoint_manager.load_latest() {
126        Ok(Some(checkpoint)) => {
127            // VALIDATION: Check if checkpoint is compatible
128            if checkpoint.episode >= config.num_episodes {
129                warn!(
130                    "Found checkpoint at episode {} but current run is only {} episodes. Starting fresh.",
131                    checkpoint.episode, config.num_episodes
132                );
133                0
134            } else if !checkpoint.model_path.exists() {
135                warn!(
136                    "Checkpoint references missing model file: {}. Starting fresh.",
137                    checkpoint.model_path.display()
138                );
139                0
140            } else {
141                // Try to load the model
142                info!("Found checkpoint at episode {}, attempting to load...", checkpoint.episode);
143
144                match AgentFactory::load(
145                    &checkpoint.model_path,
146                    config.state_dim,
147                    config.num_discrete_actions,
148                    config.num_continuous_params,
149                    &device,
150                ) {
151                    Ok(loaded_agent) => {
152                        agent = loaded_agent;
153                        epsilon = checkpoint.epsilon as f64;
154                        metrics.best_avg_quality = checkpoint.best_quality;
155                        info!("Successfully resumed from checkpoint at episode {}", checkpoint.episode);
156                        checkpoint.episode
157                    }
158                    Err(e) => {
159                        warn!("Failed to load checkpoint model: {}. Starting fresh.", e);
160                        warn!("Consider deleting checkpoint directory if corruption persists.");
161                        0
162                    }
163                }
164            }
165        }
166        Ok(None) => {
167            info!("No checkpoint found, starting fresh training");
168            0
169        }
170        Err(e) => {
171            warn!("Error loading checkpoint: {}. Starting fresh.", e);
172            0
173        }
174    };
175
176    // Progress bar
177    let pb = ProgressBar::new((config.num_episodes - start_episode) as u64);
178    pb.set_style(
179        ProgressStyle::default_bar()
180            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
181            .unwrap()
182            .progress_chars("=>-"),
183    );
184
185    for episode in start_episode..config.num_episodes {
186        // Sample HTML
187        let idx = episode % html_samples.len();
188        let (html, url) = &html_samples[idx];
189
190        let domain = extract_domain_from_url(url);
191
192        let site_profile = site_memory.get_profile(&domain);
193
194        // Reset environment
195        let mut state = env.reset(html, url.clone(), Some(site_profile))?;
196
197        let mut episode_reward = 0.0;
198        let mut done = false;
199        let mut step_info = StepInfo {
200            quality_score: 0.0,
201            text: String::new(),
202            xpath: String::new(),
203            parameters: std::collections::HashMap::new(),
204            step_count: 0,
205        };
206
207        // Episode loop
208        while !done {
209            let action = agent.select_action(&state, epsilon as f32)?;
210            let (next_state, reward, is_done, info) = env.step(action.clone())?;
211
212            episode_reward += reward;
213            done = is_done;
214            step_info = info;
215
216            // Store experience
217            let experience = crate::replay_buffer::Experience {
218                state: state.clone(),
219                action,
220                reward,
221                next_state: next_state.clone(),
222                done,
223            };
224            replay_buffer.add(experience);
225
226            // Training step
227            if replay_buffer.len() > config.batch_size * 10 {
228                let loss = agent.train_step(&mut replay_buffer, config.batch_size)?;
229                metrics.episode_losses.push(loss);
230            }
231
232            state = next_state;
233        }
234
235        // Update site profile
236        let profile = site_memory.get_profile(&domain);
237        let extraction_result = crate::site_profile::ExtractionResult {
238            text: step_info.text.clone(),
239            xpath: step_info.xpath.clone(),
240            quality_score: step_info.quality_score,
241            parameters: step_info.parameters.clone(),
242            title: None,
243            date: None,
244        };
245        profile.add_extraction(extraction_result);
246
247        // Decay epsilon
248        epsilon *= config.epsilon_decay;
249        epsilon = epsilon.max(config.epsilon_end);
250
251        // Update target network
252        if episode % config.target_update_freq == 0 {
253            agent.update_target_network();
254        }
255
256        // Record metrics
257        metrics.episode_rewards.push(episode_reward);
258        metrics.episode_qualities.push(step_info.quality_score);
259
260        // Update progress bar
261        if episode % 10 == 0 {
262            let avg_reward = if metrics.episode_rewards.len() >= 100 {
263                metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
264                    .iter()
265                    .sum::<f32>() / 100.0
266            } else {
267                episode_reward
268            };
269
270            pb.set_message(format!(
271                "Reward: {:.3}, Quality: {:.3}",
272                avg_reward, step_info.quality_score
273            ));
274        }
275        pb.inc(1);
276
277        // Save checkpoint every checkpoint_freq episodes (only for long runs)
278        if episode % config.checkpoint_freq == 0 && episode > 0 && config.num_episodes >= 5000 {
279            let checkpoint_path = config.models_dir.join(format!(
280                "checkpoint_{}_{}_ep{}.onnx",
281                config.algorithm.to_string().to_lowercase(),
282                chrono::Utc::now().format("%Y%m%d_%H%M%S"),
283                episode
284            ));
285
286            // Validate save was successful
287            match agent.save(&checkpoint_path) {
288                Ok(_) => {
289                    let avg_reward = if metrics.episode_rewards.len() >= 100 {
290                        metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
291                            .iter()
292                            .sum::<f32>() / 100.0
293                    } else {
294                        0.0
295                    };
296
297                    let avg_quality = if metrics.episode_qualities.len() >= 100 {
298                        metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
299                            .iter()
300                            .sum::<f32>() / 100.0
301                    } else {
302                        0.0
303                    };
304
305                    let checkpoint = Checkpoint::new(
306                        episode,
307                        agent.get_step_count(),
308                        avg_reward,
309                        avg_quality,
310                        metrics.best_avg_quality,
311                        epsilon as f32,
312                        checkpoint_path.clone(),
313                    );
314
315                    match checkpoint_manager.save_checkpoint(&checkpoint) {
316                        Ok(_) => {
317                            if checkpoint_path.exists() {
318                                let metadata = std::fs::metadata(&checkpoint_path)?;
319                                if metadata.len() > 0 {
320                                    site_memory.save_all()?;
321                                    info!("Checkpoint saved at episode {} ({} bytes)", episode, metadata.len());
322                                } else {
323                                    warn!("Checkpoint file is empty, may be corrupted");
324                                }
325                            } else {
326                                warn!("Checkpoint file disappeared after save");
327                            }
328                        }
329                        Err(e) => {
330                            warn!("Failed to save checkpoint metadata: {}", e);
331                        }
332                    }
333                }
334                Err(e) => {
335                    warn!("Failed to save model checkpoint: {}", e);
336                }
337            }
338        }
339    }
340
341    pb.finish_with_message("Training completed");
342
343    // Save final model with validation, metadata and with algorithm name
344    let final_path = config.models_dir.join(format!(
345        "final_model_{}.onnx",
346        config.algorithm.to_string().to_lowercase()
347    ));
348
349    let mut hyperparams = std::collections::HashMap::new();
350    hyperparams.insert("learning_rate".to_string(), config.learning_rate);
351    hyperparams.insert("batch_size".to_string(), config.batch_size as f64);
352    hyperparams.insert("gamma".to_string(), config.gamma);
353    hyperparams.insert("epsilon_decay".to_string(), config.epsilon_decay);
354    hyperparams.insert("target_update_freq".to_string(), config.target_update_freq as f64);
355    agent.save_with_metadata(&final_path, config.num_episodes, hyperparams)?;
356
357    // Verify final save
358    if final_path.exists() {
359        let metadata = std::fs::metadata(&final_path)?;
360        info!("Final model saved: {} bytes", metadata.len());
361    }
362
363    // Display metadata
364    if let Ok(model_meta) = crate::models::ModelMetadata::load_metadata(&final_path) {
365        model_meta.display();
366    }
367
368    site_memory.save_all()?;
369
370    // Save final checkpoint with algorithm-specific path
371    let final_checkpoint = Checkpoint::new(
372        config.num_episodes,
373        agent.get_step_count(),
374        metrics.episode_rewards.last().copied().unwrap_or(0.0),
375        metrics.episode_qualities.last().copied().unwrap_or(0.0),
376        metrics.best_avg_quality,
377        epsilon as f32,
378        final_path,
379    );
380    checkpoint_manager.save_checkpoint(&final_checkpoint)?;
381
382    info!("Training completed. Best avg quality: {:.3}", metrics.best_avg_quality);
383
384    Ok((agent, metrics))
385}
386
387
388/// Training with improvements (curriculum learning, improved rewards, domain extraction, etc.)
389pub fn train_with_improvements(
390    config: &Config,
391    html_samples: Vec<(String, String)>,
392) -> Result<(Box<dyn RLAgent>, TrainingMetrics)> {
393    info!("Starting OPTIMIZED training for {} episodes", config.num_episodes);
394    info!("Performance settings:");
395    info!("  - Batch size: {}", config.batch_size);
396    info!("  - Train frequency: every {} steps", config.train_freq);
397    info!("  - Gradient updates per episode: {}", config.num_train_steps_per_episode);
398    info!("  - Min replay size: {}", config.min_replay_size);
399    info!("  - Metrics window: {}", config.metrics_window);
400    info!("  - Dataset size: {}", html_samples.len());
401
402    // Initialize device and varbuilder
403    let device = if config.use_cpu_for_tuning {
404        Device::Cpu  // Force CPU for hyperparameter tuning
405    } else if crate::cuda_is_available() {
406        Device::cuda_if_available(0).unwrap_or(Device::Cpu)
407    } else {
408        Device::Cpu
409    };
410
411    // step counters:
412    let mut global_step:usize = 0;
413    let mut total_training_steps:usize = 0;
414
415    // Initialize components
416    let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
417    let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
418    let mut replay_buffer = PrioritizedReplayBuffer::new(
419        config.replay_buffer_size,
420        config.priority_alpha,
421        config.priority_beta,
422    );
423
424    // varmap is created internally by AgentFactory
425    let mut agent = AgentFactory::create(
426        config.algorithm,
427        config.state_dim,
428        config.num_discrete_actions,
429        config.num_continuous_params,
430        config.gamma as f32,
431        config.learning_rate,
432        &device,
433    )?;
434
435    let mut env = ArticleExtractionEnvironment::new(baseline_extractor.clone(), config.clone());
436    let mut metrics = TrainingMetrics { episode_rewards: vec![], episode_qualities: vec![], episode_losses: vec![], best_avg_quality: 0.0 };
437
438    // Enhanced components
439    let reward_calculator = ImprovedRewardCalculator::new(config.stopwords.clone());
440    let mut curriculum = CurriculumManager::new();
441    let mut epsilon = config.epsilon_start;
442
443    // ADDED: Checkpoint manager for improved training
444    let checkpoint_dir = config.models_dir.join("checkpoints");
445    let checkpoint_manager = CheckpointManager::new(checkpoint_dir, 5)?;
446
447    // ADDED: Resume logic similar to train_standard
448    let start_episode = match checkpoint_manager.load_latest() {
449        Ok(Some(checkpoint)) => {
450            // Validate checkpoint compatibility
451            if checkpoint.episode >= config.num_episodes {
452                warn!(
453                    "Found checkpoint at episode {} but current run is only {} episodes. Starting fresh.",
454                    checkpoint.episode, config.num_episodes
455                );
456                0
457            } else if !checkpoint.model_path.exists() {
458                warn!(
459                    "Checkpoint references missing model file: {}. Starting fresh.",
460                    checkpoint.model_path.display()
461                );
462                0
463            } else {
464                // Try to load the model
465                info!("Found checkpoint at episode {}, attempting to load...", checkpoint.episode);
466
467                match AgentFactory::load(
468                    &checkpoint.model_path,
469                    config.state_dim,
470                    config.num_discrete_actions,
471                    config.num_continuous_params,
472                    &device,
473                ) {
474                    Ok(loaded_agent) => {
475                        agent = loaded_agent;
476                        epsilon = checkpoint.epsilon as f64;
477                        metrics.best_avg_quality = checkpoint.best_quality;
478
479                        // Try to load step counts from a separate file
480                        let step_counts_path = checkpoint.model_path.with_extension("steps.json");
481                        if step_counts_path.exists() {
482                            if let Ok(step_data) = std::fs::read_to_string(&step_counts_path) {
483                                if let Ok(step_counts) = serde_json::from_str::<(usize, usize)>(&step_data) {
484                                    global_step = step_counts.0;
485                                    total_training_steps = step_counts.1;
486                                    info!("Resumed step counts: global_step={}, total_training_steps={}",
487                                          global_step, total_training_steps);
488                                }
489                            }
490                        }
491
492                        info!("Successfully resumed from checkpoint at episode {}", checkpoint.episode);
493                        checkpoint.episode
494                    }
495                    Err(e) => {
496                        warn!("Failed to load checkpoint model: {}. Starting fresh.", e);
497                        warn!("Consider deleting checkpoint directory if corruption persists.");
498                        0
499                    }
500                }
501            }
502        }
503        Ok(None) => {
504            info!("No checkpoint found, starting fresh training");
505            0
506        }
507        Err(e) => {
508            warn!("Error loading checkpoint: {}. Starting fresh.", e);
509            0
510        }
511    };
512
513    // Progress bar - start from resume point
514    let pb = ProgressBar::new((config.num_episodes - start_episode) as u64);
515    pb.set_style(
516        ProgressStyle::default_bar()
517            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
518            .unwrap()
519            .progress_chars("█▓▒░"),
520    );
521
522    for episode in start_episode..config.num_episodes {
523        let mut _episode_training_steps:usize = 0;
524        // Update curriculum
525        curriculum.update_threshold(episode);
526
527        // Sample HTML (with curriculum filtering)
528        let appropriate_samples: Vec<_> = html_samples.iter()
529            .filter(|(html, _)| curriculum.is_appropriate(html))
530            .collect();
531
532        if appropriate_samples.is_empty() {
533            warn!("No appropriate HTML samples for current curriculum");
534            break;
535        }
536
537        let idx = episode % appropriate_samples.len();
538        let (html, file_path) = appropriate_samples[idx];
539
540        // Extract domain from ground truth JSON
541        let domain = extract_domain_from_url(file_path);
542
543        // Log domain extraction (first few episodes for verification)
544        if episode < 10 {
545            info!("Episode {}: File: {}, Domain: {}", episode, file_path, domain);
546        }
547
548        let site_profile = site_memory.get_profile(&domain);
549
550        // Get baseline score
551        let baseline_result = baseline_extractor.extract(html)?;
552        let baseline_score = baseline_result.quality_score;
553
554        // Reset environment
555        let mut state = env.reset(html, file_path.clone(), Some(site_profile))?;
556
557        let mut episode_reward = 0.0;
558        let mut done = false;
559        let mut step_info = StepInfo {
560            quality_score: 0.0,
561            text: String::new(),
562            xpath: String::new(),
563            parameters: std::collections::HashMap::new(),
564            step_count: 0,
565        };
566
567        // Episode loop
568        while !done {
569            let action = agent.select_action(&state, epsilon as f32)?;
570            let (next_state, _, is_done, info) = env.step(action.clone())?;
571
572            // Calculate improved reward
573            let reward = reward_calculator.calculate_reward(&info.text, baseline_score);
574
575            episode_reward += reward;
576            done = is_done;
577            step_info = info;
578            // Increment global step counter
579            global_step += 1;
580
581            // Store experience
582            let experience = crate::replay_buffer::Experience {
583                state: state.clone(),
584                action,
585                reward,
586                next_state: next_state.clone(),
587                done,
588            };
589            replay_buffer.add(experience);
590
591            // OPTIMIZED: More frequent training after warmup
592            if replay_buffer.len() >= config.min_replay_size &&
593                global_step.is_multiple_of(config.train_freq) {
594                // ADDED: Robust error handling for training step
595                match agent.train_step(&mut replay_buffer, config.batch_size) {
596                    Ok(loss) => {
597                        // Check for NaN or infinite loss
598                        if loss.is_nan() || loss.is_infinite() {
599                            warn!("Invalid loss detected at episode {}, step {}: {}", episode, global_step, loss);
600                            warn!("Skipping this training step");
601                        } else {
602                            metrics.episode_losses.push(loss);
603                            _episode_training_steps += 1;
604                        }
605                    }
606                    Err(e) => {
607                        warn!("Training step failed at episode {}, step {}: {}", episode, global_step, e);
608                        warn!("Continuing training...");
609                        // Don't fail the entire run for one bad batch
610                    }
611                }
612            }
613
614            state = next_state;
615        }
616
617        // OPTIMIZED: Multiple gradient updates per episode
618        if replay_buffer.len() >= config.min_replay_size {
619            for update_idx in 0..config.num_train_steps_per_episode {
620                match agent.train_step(&mut replay_buffer, config.batch_size) {
621                    Ok(loss) => {
622                        if loss.is_nan() || loss.is_infinite() {
623                            warn!("Invalid loss in gradient update {} at episode {}", update_idx, episode);
624                            break; // Stop further updates this episode
625                        }
626                        metrics.episode_losses.push(loss);
627                        total_training_steps += 1;
628                    }
629                    Err(e) => {
630                        warn!("Gradient update {} failed at episode {}: {}", update_idx, episode, e);
631                        break; // Stop further updates this episode
632                    }
633                }
634            }
635        }
636
637        // Update site profile with correct domain
638        let profile = site_memory.get_profile(&domain);
639        let extraction_result = crate::site_profile::ExtractionResult {
640            text: step_info.text.clone(),
641            xpath: step_info.xpath.clone(),
642            quality_score: step_info.quality_score,
643            parameters: step_info.parameters.clone(),
644            title: None,
645            date: None,
646        };
647        profile.add_extraction(extraction_result);
648        // Save site profiles periodically
649        if episode % 100 == 0 && episode > 0 {
650            match site_memory.save_all() {
651                Ok(_) => {
652                    if episode % 500 == 0 {
653                        info!("Site profiles saved at episode {}", episode);
654                    }
655                }
656                Err(e) => {
657                    warn!("Failed to save site profiles: {}", e);
658                }
659            }
660        }
661
662        // Decay epsilon (exponential)
663        let progress = (episode as f32 / 2000.0).min(1.0);
664        epsilon = config.epsilon_start * (config.epsilon_end / config.epsilon_start).powf(progress as f64);
665        epsilon = epsilon.max(config.epsilon_end);
666
667        // Update target network
668        if episode % config.target_update_freq == 0 {
669            agent.update_target_network();
670        }
671
672        // Record metrics
673        metrics.episode_rewards.push(episode_reward);
674        metrics.episode_qualities.push(step_info.quality_score);
675
676        // Update progress bar
677        if episode % config.log_freq == 0 {
678            let window = config.metrics_window;
679            let avg_reward = if metrics.episode_rewards.len() >= window {
680                metrics.episode_rewards[metrics.episode_rewards.len() - window..]
681                    .iter()
682                    .sum::<f32>() / window as f32
683            } else if !metrics.episode_rewards.is_empty() {
684                metrics.episode_rewards.iter().sum::<f32>() / metrics.episode_rewards.len() as f32
685            } else {
686                0.0
687            };
688
689            let avg_quality = if metrics.episode_qualities.len() >= window {
690                metrics.episode_qualities[metrics.episode_qualities.len() - window..]
691                    .iter()
692                    .sum::<f32>() / window as f32
693            } else if !metrics.episode_qualities.is_empty() {
694                metrics.episode_qualities.iter().sum::<f32>() / metrics.episode_qualities.len() as f32
695            } else {
696                0.0
697            };
698
699            let curriculum_threshold = curriculum.get_threshold();
700            pb.set_message(format!(
701                "R:{:.2} Q:{:.3} ε:{:.3} C:{:.2} Steps:{}",
702                avg_reward, avg_quality, epsilon, curriculum_threshold, total_training_steps
703            ));
704        }
705        pb.inc(1);
706
707        // Save checkpoint every 500 episodes (more frequent for safety)
708        if episode % config.checkpoint_freq == 0 && episode > 0 {
709            let checkpoint_path = config.models_dir.join(format!(
710                "checkpoint_{}_{}_ep{}.onnx",
711                config.algorithm.to_string().to_lowercase(),
712                chrono::Utc::now().format("%Y%m%d_%H%M%S"),
713                episode
714            ));
715
716            match agent.save(&checkpoint_path) {
717                Ok(_) => {
718                    // Save step counts alongside model
719                    let step_counts_path = checkpoint_path.with_extension("steps.json");
720                    let step_counts = (global_step, total_training_steps);
721                    if let Ok(step_data) = serde_json::to_string(&step_counts) {
722                        let _ = std::fs::write(&step_counts_path, step_data);
723                    }
724
725                    let avg_reward = if metrics.episode_rewards.len() >= 100 {
726                        metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
727                            .iter()
728                            .sum::<f32>() / 100.0
729                    } else {
730                        0.0
731                    };
732
733                    let avg_quality = if metrics.episode_qualities.len() >= 100 {
734                        metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
735                            .iter()
736                            .sum::<f32>() / 100.0
737                    } else {
738                        0.0
739                    };
740
741                    let checkpoint = Checkpoint::new(
742                        episode,
743                        total_training_steps,
744                        avg_reward,
745                        avg_quality,
746                        metrics.best_avg_quality,
747                        epsilon as f32,
748                        checkpoint_path.clone(),
749                    );
750
751                    match checkpoint_manager.save_checkpoint(&checkpoint) {
752                        Ok(_) => {
753                            site_memory.save_all()?;
754                            info!("Improved training checkpoint saved at episode {} (global_step={})",
755                                  episode, global_step);
756                        }
757                        Err(e) => {
758                            warn!("Failed to save checkpoint metadata: {}", e);
759                        }
760                    }
761
762                    if let Ok(metadata) = std::fs::metadata(&checkpoint_path) {
763                        let file_size = metadata.len();
764                        if file_size < 10_000 {
765                            warn!("Checkpoint file suspiciously small: {} bytes", file_size);
766                        } else {
767                            info!("Checkpoint saved at episode {} ({} bytes)", episode, file_size);
768                        }
769                    }
770                }
771                Err(e) => {
772                    warn!("Failed to save model checkpoint: {}", e);
773                }
774            }
775        }
776
777        // Track best model with algorithm-specific name
778        if metrics.episode_qualities.len() >= 100 {
779            let avg_quality = metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
780                .iter()
781                .sum::<f32>() / 100.0;
782
783            if avg_quality > metrics.best_avg_quality {
784                metrics.best_avg_quality = avg_quality;
785                let best_path = config.models_dir.join(format!(
786                    "best_model_{}.onnx",
787                    config.algorithm.to_string().to_lowercase()
788                ));
789
790                match agent.save(&best_path) {
791                    Ok(_) => {
792                        if let Ok(metadata) = std::fs::metadata(&best_path) {
793                            info!("New best {} model saved with quality: {:.3} ({} bytes)",
794                                  config.algorithm, avg_quality, metadata.len());
795                        } else {
796                            info!("New best {} model saved with quality: {:.3}",
797                                  config.algorithm, avg_quality);
798                        }
799                    }
800                    Err(e) => {
801                        warn!("Failed to save best model: {}", e);
802                    }
803                }
804            }
805        }
806    }
807
808    pb.finish_with_message("Improved training completed");
809
810    // Save final model with validation, metadata and algorithm name
811    let final_path = config.models_dir.join(format!(
812        "final_model_{}.onnx",
813        config.algorithm.to_string().to_lowercase()
814    ));
815
816    let mut hyperparams = std::collections::HashMap::new();
817    hyperparams.insert("learning_rate".to_string(), config.learning_rate);
818    hyperparams.insert("batch_size".to_string(), config.batch_size as f64);
819    hyperparams.insert("gamma".to_string(), config.gamma);
820    hyperparams.insert("epsilon_decay".to_string(), config.epsilon_decay);
821    hyperparams.insert("target_update_freq".to_string(), config.target_update_freq as f64);
822    agent.save_with_metadata(&final_path, config.num_episodes, hyperparams)?;
823
824    // Verify final save
825    if final_path.exists() {
826        let metadata = std::fs::metadata(&final_path)?;
827        info!("Final model saved: {} bytes", metadata.len());
828    }
829
830    // Display metadata
831    if let Ok(model_meta) = crate::models::ModelMetadata::load_metadata(&final_path) {
832        model_meta.display();
833    }
834
835    site_memory.save_all()?;
836
837    // Save final checkpoint
838    let final_checkpoint = Checkpoint::new(
839        config.num_episodes,
840        total_training_steps,
841        metrics.episode_rewards.last().copied().unwrap_or(0.0),
842        metrics.episode_qualities.last().copied().unwrap_or(0.0),
843        metrics.best_avg_quality,
844        epsilon as f32,
845        final_path,
846    );
847    checkpoint_manager.save_checkpoint(&final_checkpoint)?;
848
849    info!("Training completed:");
850    info!("  - Total episodes: {}", config.num_episodes);
851    info!("  - Total training steps: {}", total_training_steps);
852    info!("  - Best avg quality: {:.3}", metrics.best_avg_quality);
853    info!("  - Final epsilon: {:.3}", epsilon);
854
855    Ok((agent, metrics))
856}
857
858/// Save training plot
859pub fn save_training_plot(metrics: &TrainingMetrics, output_path: &Path) -> Result<()> {
860    use plotters::prelude::*;
861
862    let root = BitMapBackend::new(output_path, (1200, 800))
863        .into_drawing_area();
864    root.fill(&WHITE)
865        .map_err(|e| ExtractionError::ModelError(format!("Plot error: {}", e)))?;
866
867    let areas = root.split_evenly((2, 1));
868    let upper = &areas[0];
869    let lower = &areas[1];
870
871    // Plot rewards
872    let max_episodes = metrics.episode_rewards.len();
873    let max_reward = metrics.episode_rewards.iter()
874        .copied()
875        .fold(f32::NEG_INFINITY, f32::max);
876    let min_reward = metrics.episode_rewards.iter()
877        .copied()
878        .fold(f32::INFINITY, f32::min);
879
880    let mut chart = ChartBuilder::on(upper)
881        .caption("Episode Rewards", ("sans-serif", 30))
882        .margin(10)
883        .x_label_area_size(30)
884        .y_label_area_size(40)
885        .build_cartesian_2d(0..max_episodes, min_reward..max_reward)
886        .map_err(|e| ExtractionError::ModelError(format!("Chart error: {}", e)))?;
887
888    chart.configure_mesh()
889        .draw()
890        .map_err(|e| ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
891
892    chart.draw_series(LineSeries::new(
893        metrics.episode_rewards.iter().enumerate().map(|(i, &r)| (i, r)),
894        &BLUE,
895    ))
896        .map_err(|e| ExtractionError::ModelError(format!("Series error: {}", e)))?;
897
898    // Plot qualities
899    let max_quality = metrics.episode_qualities.iter()
900        .copied()
901        .fold(f32::NEG_INFINITY, f32::max);
902
903    let mut chart2 = ChartBuilder::on(lower)
904        .caption("Episode Quality", ("sans-serif", 30))
905        .margin(10)
906        .x_label_area_size(30)
907        .y_label_area_size(40)
908        .build_cartesian_2d(0..max_episodes, 0.0..max_quality)
909        .map_err(|e| ExtractionError::ModelError(format!("Chart error: {}", e)))?;
910
911    chart2.configure_mesh()
912        .draw()
913        .map_err(|e| ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
914
915    chart2.draw_series(LineSeries::new(
916        metrics.episode_qualities.iter().enumerate().map(|(i, &q)| (i, q)),
917        &GREEN,
918    ))
919        .map_err(|e| ExtractionError::ModelError(format!("Series error: {}", e)))?;
920
921    root.present().map_err(|e| crate::ExtractionError::IoError(
922        std::io::Error::other(e.to_string())
923    ))?;
924
925    info!("Training plot saved to: {}", output_path.display());
926
927    Ok(())
928}