Skip to main content

content_extractor_rl/
training.rs

1// ============================================================================
2// FILE: crates/content-extractor-rl/src/training.rs
3// ============================================================================
4
5use crate::{
6    Config, ArticleExtractionEnvironment, BaselineExtractor, ExtractionError,
7    agents::{AgentFactory, RLAgent},
8};
9
10use crate::{
11    replay_buffer::PrioritizedReplayBuffer,
12    SiteProfileMemory,
13    curriculum::CurriculumManager,
14    Result,
15};
16
17use crate::environment::StepInfo;
18use rand::RngExt;
19use indicatif::{ProgressBar, ProgressStyle};
20use std::path::Path;
21use tracing::{info, warn};
22use crate::{Checkpoint, CheckpointManager};
23use candle_nn::{VarMap};
24use candle_core::Device;
25
26
27/// Extract domain from URL
28/// The second element of html_samples is now the actual URL from JSON
29fn extract_domain_from_url(url: &str) -> String {
30    use url::Url;
31
32    // Parse the URL to extract domain
33    match Url::parse(url) {
34        Ok(parsed_url) => {
35            parsed_url.host_str()
36                .map(|h| h.to_string())
37                .unwrap_or_else(|| "unknown".to_string())
38        }
39        Err(_) => {
40            // If URL parsing fails, try to extract domain directly
41            let url = url.trim();
42            let without_protocol = url.strip_prefix("https://")
43                .or_else(|| url.strip_prefix("http://"))
44                .unwrap_or(url);
45
46            // Split by '/' to get the host part
47            let host_part = without_protocol.split('/').next().unwrap_or("");
48
49            // Split by ':' to remove port (if any)
50            let domain = host_part.split(':').next().unwrap_or("");
51
52            if domain.is_empty() {
53                "unknown".to_string()
54            } else {
55                domain.to_string()
56            }
57        }
58    }
59}
60
61
/// A single training example: the page HTML, its URL, and (when available) the
/// ground-truth article text used to compute a token-F1 reward.
///
/// `From<(String, String)>` is provided so existing call sites that only have
/// `(html, url)` pairs keep working — they simply train against the
/// self-supervised text-quality proxy (ground truth `None`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrainingSample {
    /// Raw HTML of the page.
    pub html: String,
    /// URL the page was obtained from; the training loops derive the site
    /// domain from it.
    pub url: String,
    /// Labelled article text, or `None` when the sample is unlabelled.
    pub ground_truth_text: Option<String>,
}

impl TrainingSample {
    /// Construct a sample with ground-truth article text.
    #[must_use]
    pub fn with_ground_truth(html: String, url: String, ground_truth_text: String) -> Self {
        Self {
            html,
            url,
            ground_truth_text: Some(ground_truth_text),
        }
    }
}

impl From<(String, String)> for TrainingSample {
    /// Build an unlabelled sample from an `(html, url)` pair.
    fn from((html, url): (String, String)) -> Self {
        Self {
            html,
            url,
            ground_truth_text: None,
        }
    }
}
87
/// Metrics accumulated over a training run.
///
/// `Default` yields empty histories and a best average quality of `0.0`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TrainingMetrics {
    /// Total reward of each episode, in the order the episodes were run.
    pub episode_rewards: Vec<f32>,
    /// Quality score reported by the last step of each episode.
    pub episode_qualities: Vec<f32>,
    /// Loss of every recorded gradient update (not one entry per episode).
    pub episode_losses: Vec<f32>,
    /// Best trailing-window average quality observed so far.
    pub best_avg_quality: f32,
}
96
97
98/// Standard training loop with checkpoint support
99pub fn train_standard(
100    config: &Config,
101    html_samples: Vec<TrainingSample>,
102) -> Result<(Box<dyn RLAgent>, TrainingMetrics)> {
103    info!("Starting standard training for {} episodes", config.num_episodes);
104
105    let device = if config.use_cpu_for_tuning {
106        Device::Cpu
107    } else if crate::cuda_is_available() {
108        Device::cuda_if_available(0).unwrap_or(Device::Cpu)
109    } else {
110        Device::Cpu
111    };
112
113    let _varmap = VarMap::new();
114
115    // Initialize components
116    let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
117    let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
118    let mut replay_buffer = PrioritizedReplayBuffer::new(
119        config.replay_buffer_size,
120        config.priority_alpha,
121        config.priority_beta,
122    );
123
124    let mut agent = AgentFactory::create(
125        config.algorithm,
126        config.state_dim,
127        config.num_discrete_actions,
128        config.num_continuous_params,
129        config.gamma as f32,
130        config.learning_rate,
131        &device,
132    )?;
133
134    // TODO: correctly implement VERIFY initialization
135    // if !agent.online_network.verify_initialization()? {
136    //     return Err(ExtractionError::ModelError(
137    //         "Model initialization failed - weights are all zeros!".to_string()
138    //     ));
139    // }
140    // info!("Model initialized successfully with non-zero weights");
141
142    let mut env = ArticleExtractionEnvironment::new(baseline_extractor, config.clone());
143    let mut metrics = TrainingMetrics::default();
144    let mut epsilon = config.epsilon_start;
145
146    // Initialize checkpoint manager
147    let checkpoint_dir = config.models_dir.join("checkpoints");
148    let checkpoint_manager = CheckpointManager::new(checkpoint_dir, 5)?;
149
150    // CRITICAL FIX: Only resume if checkpoint is valid and from compatible run
151    let start_episode = match checkpoint_manager.load_latest() {
152        Ok(Some(checkpoint)) => {
153            // VALIDATION: Check if checkpoint is compatible
154            if checkpoint.episode >= config.num_episodes {
155                warn!(
156                    "Found checkpoint at episode {} but current run is only {} episodes. Starting fresh.",
157                    checkpoint.episode, config.num_episodes
158                );
159                0
160            } else if !checkpoint.model_path.exists() {
161                warn!(
162                    "Checkpoint references missing model file: {}. Starting fresh.",
163                    checkpoint.model_path.display()
164                );
165                0
166            } else {
167                // Try to load the model
168                info!("Found checkpoint at episode {}, attempting to load...", checkpoint.episode);
169
170                match AgentFactory::load(
171                    &checkpoint.model_path,
172                    config.state_dim,
173                    config.num_discrete_actions,
174                    config.num_continuous_params,
175                    &device,
176                ) {
177                    Ok(loaded_agent) => {
178                        agent = loaded_agent;
179                        epsilon = checkpoint.epsilon as f64;
180                        metrics.best_avg_quality = checkpoint.best_quality;
181                        info!("Successfully resumed from checkpoint at episode {}", checkpoint.episode);
182                        checkpoint.episode
183                    }
184                    Err(e) => {
185                        warn!("Failed to load checkpoint model: {}. Starting fresh.", e);
186                        warn!("Consider deleting checkpoint directory if corruption persists.");
187                        0
188                    }
189                }
190            }
191        }
192        Ok(None) => {
193            info!("No checkpoint found, starting fresh training");
194            0
195        }
196        Err(e) => {
197            warn!("Error loading checkpoint: {}. Starting fresh.", e);
198            0
199        }
200    };
201
202    // Progress bar
203    let pb = ProgressBar::new((config.num_episodes - start_episode) as u64);
204    pb.set_style(
205        ProgressStyle::default_bar()
206            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
207            .unwrap()
208            .progress_chars("=>-"),
209    );
210
211    for episode in start_episode..config.num_episodes {
212        // Sample HTML at random to decorrelate consecutive experiences (the
213        // old `episode % len` cycling produced a fixed visitation order).
214        let idx = rand::rng().random_range(0..html_samples.len());
215        let sample = &html_samples[idx];
216        let (html, url) = (&sample.html, &sample.url);
217
218        let domain = extract_domain_from_url(url);
219
220        let site_profile = site_memory.get_profile(&domain);
221
222        // Reset environment
223        let mut state = env.reset(
224            html,
225            url.clone(),
226            sample.ground_truth_text.as_deref(),
227            Some(site_profile),
228        )?;
229
230        let mut episode_reward = 0.0;
231        let mut done = false;
232        let mut step_info = StepInfo {
233            quality_score: 0.0,
234            text: String::new(),
235            xpath: String::new(),
236            parameters: std::collections::HashMap::new(),
237            step_count: 0,
238        };
239
240        // Episode loop
241        while !done {
242            let action = agent.select_action(&state, epsilon as f32)?;
243            let (next_state, reward, is_done, info) = env.step(action.clone())?;
244
245            episode_reward += reward;
246            done = is_done;
247            step_info = info;
248
249            // Store experience
250            let experience = crate::replay_buffer::Experience {
251                state: state.clone(),
252                action,
253                reward,
254                next_state: next_state.clone(),
255                done,
256            };
257            replay_buffer.add(experience);
258
259            // Training step
260            if replay_buffer.len() > config.batch_size * 10 {
261                let loss = agent.train_step(&mut replay_buffer, config.batch_size)?;
262                metrics.episode_losses.push(loss);
263            }
264
265            state = next_state;
266        }
267
268        // Update site profile
269        let profile = site_memory.get_profile(&domain);
270        let extraction_result = crate::site_profile::ExtractionResult {
271            text: step_info.text.clone(),
272            xpath: step_info.xpath.clone(),
273            quality_score: step_info.quality_score,
274            parameters: step_info.parameters.clone(),
275            title: None,
276            date: None,
277        };
278        profile.add_extraction(extraction_result);
279
280        // Decay epsilon
281        epsilon *= config.epsilon_decay;
282        epsilon = epsilon.max(config.epsilon_end);
283
284        // Update target network
285        if episode % config.target_update_freq == 0 {
286            agent.update_target_network();
287        }
288
289        // Record metrics
290        metrics.episode_rewards.push(episode_reward);
291        metrics.episode_qualities.push(step_info.quality_score);
292
293        // Update progress bar
294        if episode % 10 == 0 {
295            let avg_reward = if metrics.episode_rewards.len() >= 100 {
296                metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
297                    .iter()
298                    .sum::<f32>() / 100.0
299            } else {
300                episode_reward
301            };
302
303            pb.set_message(format!(
304                "Reward: {:.3}, Quality: {:.3}",
305                avg_reward, step_info.quality_score
306            ));
307        }
308        pb.inc(1);
309
310        // Save checkpoint every checkpoint_freq episodes (only for long runs)
311        if episode % config.checkpoint_freq == 0 && episode > 0 && config.num_episodes >= 5000 {
312            let checkpoint_path = config.models_dir.join(format!(
313                "checkpoint_{}_{}_ep{}.onnx",
314                config.algorithm.to_string().to_lowercase(),
315                chrono::Utc::now().format("%Y%m%d_%H%M%S"),
316                episode
317            ));
318
319            // Validate save was successful
320            match agent.save(&checkpoint_path) {
321                Ok(_) => {
322                    let avg_reward = if metrics.episode_rewards.len() >= 100 {
323                        metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
324                            .iter()
325                            .sum::<f32>() / 100.0
326                    } else {
327                        0.0
328                    };
329
330                    let avg_quality = if metrics.episode_qualities.len() >= 100 {
331                        metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
332                            .iter()
333                            .sum::<f32>() / 100.0
334                    } else {
335                        0.0
336                    };
337
338                    let checkpoint = Checkpoint::new(
339                        episode,
340                        agent.get_step_count(),
341                        avg_reward,
342                        avg_quality,
343                        metrics.best_avg_quality,
344                        epsilon as f32,
345                        checkpoint_path.clone(),
346                    );
347
348                    match checkpoint_manager.save_checkpoint(&checkpoint) {
349                        Ok(_) => {
350                            if checkpoint_path.exists() {
351                                let metadata = std::fs::metadata(&checkpoint_path)?;
352                                if metadata.len() > 0 {
353                                    site_memory.save_all()?;
354                                    info!("Checkpoint saved at episode {} ({} bytes)", episode, metadata.len());
355                                } else {
356                                    warn!("Checkpoint file is empty, may be corrupted");
357                                }
358                            } else {
359                                warn!("Checkpoint file disappeared after save");
360                            }
361                        }
362                        Err(e) => {
363                            warn!("Failed to save checkpoint metadata: {}", e);
364                        }
365                    }
366                }
367                Err(e) => {
368                    warn!("Failed to save model checkpoint: {}", e);
369                }
370            }
371        }
372    }
373
374    pb.finish_with_message("Training completed");
375
376    // Save final model with validation, metadata and with algorithm name
377    let final_path = config.models_dir.join(format!(
378        "final_model_{}.onnx",
379        config.algorithm.to_string().to_lowercase()
380    ));
381
382    let mut hyperparams = std::collections::HashMap::new();
383    hyperparams.insert("learning_rate".to_string(), config.learning_rate);
384    hyperparams.insert("batch_size".to_string(), config.batch_size as f64);
385    hyperparams.insert("gamma".to_string(), config.gamma);
386    hyperparams.insert("epsilon_decay".to_string(), config.epsilon_decay);
387    hyperparams.insert("target_update_freq".to_string(), config.target_update_freq as f64);
388    agent.save_with_metadata(&final_path, config.num_episodes, hyperparams)?;
389
390    // Verify final save
391    if final_path.exists() {
392        let metadata = std::fs::metadata(&final_path)?;
393        info!("Final model saved: {} bytes", metadata.len());
394    }
395
396    // Display metadata
397    if let Ok(model_meta) = crate::models::ModelMetadata::load_metadata(&final_path) {
398        model_meta.display();
399    }
400
401    site_memory.save_all()?;
402
403    // Save final checkpoint with algorithm-specific path
404    let final_checkpoint = Checkpoint::new(
405        config.num_episodes,
406        agent.get_step_count(),
407        metrics.episode_rewards.last().copied().unwrap_or(0.0),
408        metrics.episode_qualities.last().copied().unwrap_or(0.0),
409        metrics.best_avg_quality,
410        epsilon as f32,
411        final_path,
412    );
413    checkpoint_manager.save_checkpoint(&final_checkpoint)?;
414
415    info!("Training completed. Best avg quality: {:.3}", metrics.best_avg_quality);
416
417    Ok((agent, metrics))
418}
419
420
421/// Training with improvements (curriculum learning, improved rewards, domain extraction, etc.)
422pub fn train_with_improvements(
423    config: &Config,
424    html_samples: Vec<TrainingSample>,
425) -> Result<(Box<dyn RLAgent>, TrainingMetrics)> {
426    info!("Starting OPTIMIZED training for {} episodes", config.num_episodes);
427    info!("Performance settings:");
428    info!("  - Batch size: {}", config.batch_size);
429    info!("  - Train frequency: every {} steps", config.train_freq);
430    info!("  - Gradient updates per episode: {}", config.num_train_steps_per_episode);
431    info!("  - Min replay size: {}", config.min_replay_size);
432    info!("  - Metrics window: {}", config.metrics_window);
433    info!("  - Dataset size: {}", html_samples.len());
434
435    // Initialize device and varbuilder
436    let device = if config.use_cpu_for_tuning {
437        Device::Cpu  // Force CPU for hyperparameter tuning
438    } else if crate::cuda_is_available() {
439        Device::cuda_if_available(0).unwrap_or(Device::Cpu)
440    } else {
441        Device::Cpu
442    };
443
444    // step counters:
445    let mut global_step:usize = 0;
446    let mut total_training_steps:usize = 0;
447
448    // Initialize components
449    let baseline_extractor = BaselineExtractor::new(config.stopwords.clone());
450    let mut site_memory = SiteProfileMemory::new(&config.site_profiles_dir)?;
451    let mut replay_buffer = PrioritizedReplayBuffer::new(
452        config.replay_buffer_size,
453        config.priority_alpha,
454        config.priority_beta,
455    );
456
457    // varmap is created internally by AgentFactory
458    let mut agent = AgentFactory::create(
459        config.algorithm,
460        config.state_dim,
461        config.num_discrete_actions,
462        config.num_continuous_params,
463        config.gamma as f32,
464        config.learning_rate,
465        &device,
466    )?;
467
468    let mut env = ArticleExtractionEnvironment::new(baseline_extractor.clone(), config.clone());
469    let mut metrics = TrainingMetrics { episode_rewards: vec![], episode_qualities: vec![], episode_losses: vec![], best_avg_quality: 0.0 };
470
471    // Enhanced components
472    let mut curriculum = CurriculumManager::new();
473    let mut epsilon = config.epsilon_start;
474
475    // ADDED: Checkpoint manager for improved training
476    let checkpoint_dir = config.models_dir.join("checkpoints");
477    let checkpoint_manager = CheckpointManager::new(checkpoint_dir, 5)?;
478
479    // ADDED: Resume logic similar to train_standard
480    let start_episode = match checkpoint_manager.load_latest() {
481        Ok(Some(checkpoint)) => {
482            // Validate checkpoint compatibility
483            if checkpoint.episode >= config.num_episodes {
484                warn!(
485                    "Found checkpoint at episode {} but current run is only {} episodes. Starting fresh.",
486                    checkpoint.episode, config.num_episodes
487                );
488                0
489            } else if !checkpoint.model_path.exists() {
490                warn!(
491                    "Checkpoint references missing model file: {}. Starting fresh.",
492                    checkpoint.model_path.display()
493                );
494                0
495            } else {
496                // Try to load the model
497                info!("Found checkpoint at episode {}, attempting to load...", checkpoint.episode);
498
499                match AgentFactory::load(
500                    &checkpoint.model_path,
501                    config.state_dim,
502                    config.num_discrete_actions,
503                    config.num_continuous_params,
504                    &device,
505                ) {
506                    Ok(loaded_agent) => {
507                        agent = loaded_agent;
508                        epsilon = checkpoint.epsilon as f64;
509                        metrics.best_avg_quality = checkpoint.best_quality;
510
511                        // Try to load step counts from a separate file
512                        let step_counts_path = checkpoint.model_path.with_extension("steps.json");
513                        if step_counts_path.exists() {
514                            if let Ok(step_data) = std::fs::read_to_string(&step_counts_path) {
515                                if let Ok(step_counts) = serde_json::from_str::<(usize, usize)>(&step_data) {
516                                    global_step = step_counts.0;
517                                    total_training_steps = step_counts.1;
518                                    info!("Resumed step counts: global_step={}, total_training_steps={}",
519                                          global_step, total_training_steps);
520                                }
521                            }
522                        }
523
524                        info!("Successfully resumed from checkpoint at episode {}", checkpoint.episode);
525                        checkpoint.episode
526                    }
527                    Err(e) => {
528                        warn!("Failed to load checkpoint model: {}. Starting fresh.", e);
529                        warn!("Consider deleting checkpoint directory if corruption persists.");
530                        0
531                    }
532                }
533            }
534        }
535        Ok(None) => {
536            info!("No checkpoint found, starting fresh training");
537            0
538        }
539        Err(e) => {
540            warn!("Error loading checkpoint: {}. Starting fresh.", e);
541            0
542        }
543    };
544
545    // Progress bar - start from resume point
546    let pb = ProgressBar::new((config.num_episodes - start_episode) as u64);
547    pb.set_style(
548        ProgressStyle::default_bar()
549            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}")
550            .unwrap()
551            .progress_chars("█▓▒░"),
552    );
553
554    for episode in start_episode..config.num_episodes {
555        let mut _episode_training_steps:usize = 0;
556        // Update curriculum
557        curriculum.update_threshold(episode);
558
559        // Sample HTML (with curriculum filtering)
560        let appropriate_samples: Vec<_> = html_samples.iter()
561            .filter(|s| curriculum.is_appropriate(&s.html))
562            .collect();
563
564        if appropriate_samples.is_empty() {
565            warn!("No appropriate HTML samples for current curriculum");
566            break;
567        }
568
569        let idx = rand::rng().random_range(0..appropriate_samples.len());
570        let sample = appropriate_samples[idx];
571        let html = &sample.html;
572        let file_path = &sample.url;
573
574        // Extract domain from ground truth JSON
575        let domain = extract_domain_from_url(file_path);
576
577        // Log domain extraction (first few episodes for verification)
578        if episode < 10 {
579            info!("Episode {}: File: {}, Domain: {}", episode, file_path, domain);
580        }
581
582        let site_profile = site_memory.get_profile(&domain);
583
584        // Reset environment with ground-truth text so the reward is token F1
585        // against the labelled article (falls back to a quality proxy if absent).
586        let mut state = env.reset(
587            html,
588            file_path.clone(),
589            sample.ground_truth_text.as_deref(),
590            Some(site_profile),
591        )?;
592
593        let mut episode_reward = 0.0;
594        let mut done = false;
595        let mut step_info = StepInfo {
596            quality_score: 0.0,
597            text: String::new(),
598            xpath: String::new(),
599            parameters: std::collections::HashMap::new(),
600            step_count: 0,
601        };
602
603        // Episode loop
604        while !done {
605            let action = agent.select_action(&state, epsilon as f32)?;
606            // The environment now computes the reward as token F1 against the
607            // ground-truth article (or a quality proxy when no GT is present),
608            // so the selected node and tuned params actually drive the signal.
609            let (next_state, reward, is_done, info) = env.step(action.clone())?;
610
611            episode_reward += reward;
612            done = is_done;
613            step_info = info;
614            // Increment global step counter
615            global_step += 1;
616
617            // Store experience
618            let experience = crate::replay_buffer::Experience {
619                state: state.clone(),
620                action,
621                reward,
622                next_state: next_state.clone(),
623                done,
624            };
625            replay_buffer.add(experience);
626
627            // OPTIMIZED: More frequent training after warmup
628            if replay_buffer.len() >= config.min_replay_size &&
629                global_step.is_multiple_of(config.train_freq) {
630                // ADDED: Robust error handling for training step
631                match agent.train_step(&mut replay_buffer, config.batch_size) {
632                    Ok(loss) => {
633                        // Check for NaN or infinite loss
634                        if loss.is_nan() || loss.is_infinite() {
635                            warn!("Invalid loss detected at episode {}, step {}: {}", episode, global_step, loss);
636                            warn!("Skipping this training step");
637                        } else {
638                            metrics.episode_losses.push(loss);
639                            _episode_training_steps += 1;
640                        }
641                    }
642                    Err(e) => {
643                        warn!("Training step failed at episode {}, step {}: {}", episode, global_step, e);
644                        warn!("Continuing training...");
645                        // Don't fail the entire run for one bad batch
646                    }
647                }
648            }
649
650            state = next_state;
651        }
652
653        // OPTIMIZED: Multiple gradient updates per episode
654        if replay_buffer.len() >= config.min_replay_size {
655            for update_idx in 0..config.num_train_steps_per_episode {
656                match agent.train_step(&mut replay_buffer, config.batch_size) {
657                    Ok(loss) => {
658                        if loss.is_nan() || loss.is_infinite() {
659                            warn!("Invalid loss in gradient update {} at episode {}", update_idx, episode);
660                            break; // Stop further updates this episode
661                        }
662                        metrics.episode_losses.push(loss);
663                        total_training_steps += 1;
664                    }
665                    Err(e) => {
666                        warn!("Gradient update {} failed at episode {}: {}", update_idx, episode, e);
667                        break; // Stop further updates this episode
668                    }
669                }
670            }
671        }
672
673        // Update site profile with correct domain
674        let profile = site_memory.get_profile(&domain);
675        let extraction_result = crate::site_profile::ExtractionResult {
676            text: step_info.text.clone(),
677            xpath: step_info.xpath.clone(),
678            quality_score: step_info.quality_score,
679            parameters: step_info.parameters.clone(),
680            title: None,
681            date: None,
682        };
683        profile.add_extraction(extraction_result);
684        // Save site profiles periodically
685        if episode % 100 == 0 && episode > 0 {
686            match site_memory.save_all() {
687                Ok(_) => {
688                    if episode % 500 == 0 {
689                        info!("Site profiles saved at episode {}", episode);
690                    }
691                }
692                Err(e) => {
693                    warn!("Failed to save site profiles: {}", e);
694                }
695            }
696        }
697
698        // Decay epsilon (exponential)
699        let progress = (episode as f32 / 2000.0).min(1.0);
700        epsilon = config.epsilon_start * (config.epsilon_end / config.epsilon_start).powf(progress as f64);
701        epsilon = epsilon.max(config.epsilon_end);
702
703        // Update target network
704        if episode % config.target_update_freq == 0 {
705            agent.update_target_network();
706        }
707
708        // Record metrics
709        metrics.episode_rewards.push(episode_reward);
710        metrics.episode_qualities.push(step_info.quality_score);
711
712        // Update progress bar
713        if episode % config.log_freq == 0 {
714            let window = config.metrics_window;
715            let avg_reward = if metrics.episode_rewards.len() >= window {
716                metrics.episode_rewards[metrics.episode_rewards.len() - window..]
717                    .iter()
718                    .sum::<f32>() / window as f32
719            } else if !metrics.episode_rewards.is_empty() {
720                metrics.episode_rewards.iter().sum::<f32>() / metrics.episode_rewards.len() as f32
721            } else {
722                0.0
723            };
724
725            let avg_quality = if metrics.episode_qualities.len() >= window {
726                metrics.episode_qualities[metrics.episode_qualities.len() - window..]
727                    .iter()
728                    .sum::<f32>() / window as f32
729            } else if !metrics.episode_qualities.is_empty() {
730                metrics.episode_qualities.iter().sum::<f32>() / metrics.episode_qualities.len() as f32
731            } else {
732                0.0
733            };
734
735            let curriculum_threshold = curriculum.get_threshold();
736            pb.set_message(format!(
737                "R:{:.2} Q:{:.3} ε:{:.3} C:{:.2} Steps:{}",
738                avg_reward, avg_quality, epsilon, curriculum_threshold, total_training_steps
739            ));
740        }
741        pb.inc(1);
742
743        // Save checkpoint every 500 episodes (more frequent for safety)
744        if episode % config.checkpoint_freq == 0 && episode > 0 {
745            let checkpoint_path = config.models_dir.join(format!(
746                "checkpoint_{}_{}_ep{}.onnx",
747                config.algorithm.to_string().to_lowercase(),
748                chrono::Utc::now().format("%Y%m%d_%H%M%S"),
749                episode
750            ));
751
752            match agent.save(&checkpoint_path) {
753                Ok(_) => {
754                    // Save step counts alongside model
755                    let step_counts_path = checkpoint_path.with_extension("steps.json");
756                    let step_counts = (global_step, total_training_steps);
757                    if let Ok(step_data) = serde_json::to_string(&step_counts) {
758                        let _ = std::fs::write(&step_counts_path, step_data);
759                    }
760
761                    let avg_reward = if metrics.episode_rewards.len() >= 100 {
762                        metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
763                            .iter()
764                            .sum::<f32>() / 100.0
765                    } else {
766                        0.0
767                    };
768
769                    let avg_quality = if metrics.episode_qualities.len() >= 100 {
770                        metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
771                            .iter()
772                            .sum::<f32>() / 100.0
773                    } else {
774                        0.0
775                    };
776
777                    let checkpoint = Checkpoint::new(
778                        episode,
779                        total_training_steps,
780                        avg_reward,
781                        avg_quality,
782                        metrics.best_avg_quality,
783                        epsilon as f32,
784                        checkpoint_path.clone(),
785                    );
786
787                    match checkpoint_manager.save_checkpoint(&checkpoint) {
788                        Ok(_) => {
789                            site_memory.save_all()?;
790                            info!("Improved training checkpoint saved at episode {} (global_step={})",
791                                  episode, global_step);
792                        }
793                        Err(e) => {
794                            warn!("Failed to save checkpoint metadata: {}", e);
795                        }
796                    }
797
798                    if let Ok(metadata) = std::fs::metadata(&checkpoint_path) {
799                        let file_size = metadata.len();
800                        if file_size < 10_000 {
801                            warn!("Checkpoint file suspiciously small: {} bytes", file_size);
802                        } else {
803                            info!("Checkpoint saved at episode {} ({} bytes)", episode, file_size);
804                        }
805                    }
806                }
807                Err(e) => {
808                    warn!("Failed to save model checkpoint: {}", e);
809                }
810            }
811        }
812
813        // Track best model with algorithm-specific name
814        if metrics.episode_qualities.len() >= 100 {
815            let avg_quality = metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
816                .iter()
817                .sum::<f32>() / 100.0;
818
819            if avg_quality > metrics.best_avg_quality {
820                metrics.best_avg_quality = avg_quality;
821                let best_path = config.models_dir.join(format!(
822                    "best_model_{}.onnx",
823                    config.algorithm.to_string().to_lowercase()
824                ));
825
826                match agent.save(&best_path) {
827                    Ok(_) => {
828                        if let Ok(metadata) = std::fs::metadata(&best_path) {
829                            info!("New best {} model saved with quality: {:.3} ({} bytes)",
830                                  config.algorithm, avg_quality, metadata.len());
831                        } else {
832                            info!("New best {} model saved with quality: {:.3}",
833                                  config.algorithm, avg_quality);
834                        }
835                    }
836                    Err(e) => {
837                        warn!("Failed to save best model: {}", e);
838                    }
839                }
840            }
841        }
842    }
843
844    pb.finish_with_message("Improved training completed");
845
846    // Save final model with validation, metadata and algorithm name
847    let final_path = config.models_dir.join(format!(
848        "final_model_{}.onnx",
849        config.algorithm.to_string().to_lowercase()
850    ));
851
852    let mut hyperparams = std::collections::HashMap::new();
853    hyperparams.insert("learning_rate".to_string(), config.learning_rate);
854    hyperparams.insert("batch_size".to_string(), config.batch_size as f64);
855    hyperparams.insert("gamma".to_string(), config.gamma);
856    hyperparams.insert("epsilon_decay".to_string(), config.epsilon_decay);
857    hyperparams.insert("target_update_freq".to_string(), config.target_update_freq as f64);
858    agent.save_with_metadata(&final_path, config.num_episodes, hyperparams)?;
859
860    // Verify final save
861    if final_path.exists() {
862        let metadata = std::fs::metadata(&final_path)?;
863        info!("Final model saved: {} bytes", metadata.len());
864    }
865
866    // Display metadata
867    if let Ok(model_meta) = crate::models::ModelMetadata::load_metadata(&final_path) {
868        model_meta.display();
869    }
870
871    site_memory.save_all()?;
872
873    // Save final checkpoint
874    let final_checkpoint = Checkpoint::new(
875        config.num_episodes,
876        total_training_steps,
877        metrics.episode_rewards.last().copied().unwrap_or(0.0),
878        metrics.episode_qualities.last().copied().unwrap_or(0.0),
879        metrics.best_avg_quality,
880        epsilon as f32,
881        final_path,
882    );
883    checkpoint_manager.save_checkpoint(&final_checkpoint)?;
884
885    info!("Training completed:");
886    info!("  - Total episodes: {}", config.num_episodes);
887    info!("  - Total training steps: {}", total_training_steps);
888    info!("  - Best avg quality: {:.3}", metrics.best_avg_quality);
889    info!("  - Final epsilon: {:.3}", epsilon);
890
891    Ok((agent, metrics))
892}
893
894/// Save training plot
895pub fn save_training_plot(metrics: &TrainingMetrics, output_path: &Path) -> Result<()> {
896    use plotters::prelude::*;
897
898    let root = BitMapBackend::new(output_path, (1200, 800))
899        .into_drawing_area();
900    root.fill(&WHITE)
901        .map_err(|e| ExtractionError::ModelError(format!("Plot error: {}", e)))?;
902
903    let areas = root.split_evenly((2, 1));
904    let upper = &areas[0];
905    let lower = &areas[1];
906
907    // Plot rewards
908    let max_episodes = metrics.episode_rewards.len();
909    let max_reward = metrics.episode_rewards.iter()
910        .copied()
911        .fold(f32::NEG_INFINITY, f32::max);
912    let min_reward = metrics.episode_rewards.iter()
913        .copied()
914        .fold(f32::INFINITY, f32::min);
915
916    let mut chart = ChartBuilder::on(upper)
917        .caption("Episode Rewards", ("sans-serif", 30))
918        .margin(10)
919        .x_label_area_size(30)
920        .y_label_area_size(40)
921        .build_cartesian_2d(0..max_episodes, min_reward..max_reward)
922        .map_err(|e| ExtractionError::ModelError(format!("Chart error: {}", e)))?;
923
924    chart.configure_mesh()
925        .draw()
926        .map_err(|e| ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
927
928    chart.draw_series(LineSeries::new(
929        metrics.episode_rewards.iter().enumerate().map(|(i, &r)| (i, r)),
930        &BLUE,
931    ))
932        .map_err(|e| ExtractionError::ModelError(format!("Series error: {}", e)))?;
933
934    // Plot qualities
935    let max_quality = metrics.episode_qualities.iter()
936        .copied()
937        .fold(f32::NEG_INFINITY, f32::max);
938
939    let mut chart2 = ChartBuilder::on(lower)
940        .caption("Episode Quality", ("sans-serif", 30))
941        .margin(10)
942        .x_label_area_size(30)
943        .y_label_area_size(40)
944        .build_cartesian_2d(0..max_episodes, 0.0..max_quality)
945        .map_err(|e| ExtractionError::ModelError(format!("Chart error: {}", e)))?;
946
947    chart2.configure_mesh()
948        .draw()
949        .map_err(|e| ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
950
951    chart2.draw_series(LineSeries::new(
952        metrics.episode_qualities.iter().enumerate().map(|(i, &q)| (i, q)),
953        &GREEN,
954    ))
955        .map_err(|e| ExtractionError::ModelError(format!("Series error: {}", e)))?;
956
957    root.present().map_err(|e| crate::ExtractionError::IoError(
958        std::io::Error::other(e.to_string())
959    ))?;
960
961    info!("Training plot saved to: {}", output_path.display());
962
963    Ok(())
964}