Skip to main content

content_extractor_rl/evaluation/
algorithm_comparison.rs

1// ============================================================================
2// FILE: crates/content-extractor-rl/src/evaluation/algorithm_comparison.rs
3// ============================================================================
4
5use crate::{
6    Config, Result, agents::AlgorithmType,
7    training::{train_standard},
8};
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11use tracing::info;
12
13/// Results for a single algorithm
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AlgorithmResult {
16    pub algorithm: AlgorithmType,
17    pub runs: Vec<RunResult>,
18    pub avg_quality: f64,
19    pub std_quality: f64,
20    pub avg_reward: f64,
21    pub std_reward: f64,
22    pub avg_training_time: f64,
23}
24
25/// Results for a single run
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct RunResult {
28    pub run_number: usize,
29    pub final_quality: f32,
30    pub final_reward: f32,
31    pub avg_quality_last_100: f32,
32    pub avg_reward_last_100: f32,
33    pub training_time_seconds: f64,
34}
35
36/// Comparison report across all algorithms
37#[derive(Debug, Serialize, Deserialize)]
38pub struct ComparisonReport {
39    pub algorithms: Vec<AlgorithmResult>,
40    pub best_by_quality: String,
41    pub best_by_reward: String,
42    pub best_by_time: String,
43    pub config: ComparisonConfig,
44}
45
46/// Configuration for comparison
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ComparisonConfig {
49    pub episodes: usize,
50    pub runs: usize,
51    pub dataset_size: usize,
52}
53
54/// Algorithm comparator
55pub struct AlgorithmComparator {
56    config: Config,
57    output_dir: PathBuf,
58}
59
60impl AlgorithmComparator {
61    pub fn new(config: Config, output_dir: PathBuf) -> Result<Self> {
62        std::fs::create_dir_all(&output_dir)?;
63        Ok(Self { config, output_dir })
64    }
65
66    /// Compare multiple algorithms
67    pub fn compare_algorithms(
68        &self,
69        algorithms: Vec<AlgorithmType>,
70        html_samples: Vec<(String, String)>,
71        episodes: usize,
72        runs: usize,
73    ) -> Result<ComparisonReport> {
74        info!("Starting algorithm comparison");
75        info!("Algorithms: {:?}", algorithms);
76        info!("Episodes: {}, Runs per algorithm: {}", episodes, runs);
77
78        let mut results = Vec::new();
79
80        for algorithm in algorithms {
81            info!("Evaluating algorithm: {}", algorithm);
82            let algo_result = self.evaluate_algorithm(
83                algorithm,
84                html_samples.clone(),
85                episodes,
86                runs,
87            )?;
88            results.push(algo_result);
89        }
90
91        // Find best algorithms
92        let best_by_quality = results.iter()
93            .max_by(|a, b| a.avg_quality.partial_cmp(&b.avg_quality).unwrap())
94            .map(|r| r.algorithm.to_string())
95            .unwrap_or_else(|| "None".to_string());
96
97        let best_by_reward = results.iter()
98            .max_by(|a, b| a.avg_reward.partial_cmp(&b.avg_reward).unwrap())
99            .map(|r| r.algorithm.to_string())
100            .unwrap_or_else(|| "None".to_string());
101
102        let best_by_time = results.iter()
103            .min_by(|a, b| a.avg_training_time.partial_cmp(&b.avg_training_time).unwrap())
104            .map(|r| r.algorithm.to_string())
105            .unwrap_or_else(|| "None".to_string());
106
107        let report = ComparisonReport {
108            algorithms: results,
109            best_by_quality: best_by_quality.clone(),
110            best_by_reward: best_by_reward.clone(),
111            best_by_time: best_by_time.clone(),
112            config: ComparisonConfig {
113                episodes,
114                runs,
115                dataset_size: html_samples.len(),
116            },
117        };
118
119        // Save report
120        self.save_report(&report)?;
121
122        // Print summary
123        self.print_summary(&report);
124
125        Ok(report)
126    }
127
128    /// Evaluate a single algorithm with multiple runs
129    fn evaluate_algorithm(
130        &self,
131        algorithm: AlgorithmType,
132        html_samples: Vec<(String, String)>,
133        episodes: usize,
134        runs: usize,
135    ) -> Result<AlgorithmResult> {
136        let mut run_results = Vec::new();
137
138        for run in 0..runs {
139            info!("Algorithm: {}, Run: {}/{}", algorithm, run + 1, runs);
140
141            let start_time = std::time::Instant::now();
142
143            // Create config for this run
144            let mut run_config = self.config.clone();
145            run_config.algorithm = algorithm;
146            run_config.num_episodes = episodes;
147
148            // Train
149            let run_samples: Vec<crate::TrainingSample> =
150                html_samples.clone().into_iter().map(Into::into).collect();
151            let (_agent, metrics) = train_standard(&run_config, run_samples)?;
152
153            let training_time = start_time.elapsed().as_secs_f64();
154
155            // Calculate metrics
156            let final_quality = metrics.episode_qualities.last().copied().unwrap_or(0.0);
157            let final_reward = metrics.episode_rewards.last().copied().unwrap_or(0.0);
158
159            let avg_quality_last_100 = if metrics.episode_qualities.len() >= 100 {
160                metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
161                    .iter()
162                    .sum::<f32>() / 100.0
163            } else if !metrics.episode_qualities.is_empty() {
164                metrics.episode_qualities.iter().sum::<f32>() / metrics.episode_qualities.len() as f32
165            } else {
166                0.0
167            };
168
169            let avg_reward_last_100 = if metrics.episode_rewards.len() >= 100 {
170                metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
171                    .iter()
172                    .sum::<f32>() / 100.0
173            } else if !metrics.episode_rewards.is_empty() {
174                metrics.episode_rewards.iter().sum::<f32>() / metrics.episode_rewards.len() as f32
175            } else {
176                0.0
177            };
178
179            let run_result = RunResult {
180                run_number: run,
181                final_quality,
182                final_reward,
183                avg_quality_last_100,
184                avg_reward_last_100,
185                training_time_seconds: training_time,
186            };
187
188            run_results.push(run_result);
189
190            info!("Run {} complete: quality={:.4}, reward={:.4}, time={:.2}s",
191                  run + 1, avg_quality_last_100, avg_reward_last_100, training_time);
192        }
193
194        // Calculate statistics
195        let avg_quality = run_results.iter()
196            .map(|r| r.avg_quality_last_100 as f64)
197            .sum::<f64>() / runs as f64;
198
199        let std_quality = {
200            let variance = run_results.iter()
201                .map(|r| {
202                    let diff = r.avg_quality_last_100 as f64 - avg_quality;
203                    diff * diff
204                })
205                .sum::<f64>() / runs as f64;
206            variance.sqrt()
207        };
208
209        let avg_reward = run_results.iter()
210            .map(|r| r.avg_reward_last_100 as f64)
211            .sum::<f64>() / runs as f64;
212
213        let std_reward = {
214            let variance = run_results.iter()
215                .map(|r| {
216                    let diff = r.avg_reward_last_100 as f64 - avg_reward;
217                    diff * diff
218                })
219                .sum::<f64>() / runs as f64;
220            variance.sqrt()
221        };
222
223        let avg_training_time = run_results.iter()
224            .map(|r| r.training_time_seconds)
225            .sum::<f64>() / runs as f64;
226
227        Ok(AlgorithmResult {
228            algorithm,
229            runs: run_results,
230            avg_quality,
231            std_quality,
232            avg_reward,
233            std_reward,
234            avg_training_time,
235        })
236    }
237
238    /// Save comparison report
239    fn save_report(&self, report: &ComparisonReport) -> Result<()> {
240        let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
241        let path = self.output_dir.join(format!("comparison_report_{}.json", timestamp));
242
243        let json = serde_json::to_string_pretty(report)?;
244        std::fs::write(&path, json)?;
245
246        info!("Comparison report saved to: {}", path.display());
247        Ok(())
248    }
249
250    /// Print summary to console
251    fn print_summary(&self, report: &ComparisonReport) {
252        println!("\n{}", "=".repeat(80));
253        println!("ALGORITHM COMPARISON RESULTS");
254        println!("{}", "=".repeat(80));
255        println!("Episodes: {}, Runs per algorithm: {}",
256                 report.config.episodes, report.config.runs);
257        println!("Dataset size: {}", report.config.dataset_size);
258        println!("{}", "=".repeat(80));
259
260        for result in &report.algorithms {
261            println!("\nAlgorithm: {}", result.algorithm);
262            println!("  Average Quality:  {:.4} ± {:.4}", result.avg_quality, result.std_quality);
263            println!("  Average Reward:   {:.4} ± {:.4}", result.avg_reward, result.std_reward);
264            println!("  Average Time:     {:.2}s", result.avg_training_time);
265            println!("  Individual runs:");
266            for run in &result.runs {
267                println!("    Run {}: quality={:.4}, reward={:.4}, time={:.2}s",
268                         run.run_number + 1,
269                         run.avg_quality_last_100,
270                         run.avg_reward_last_100,
271                         run.training_time_seconds);
272            }
273        }
274
275        println!("\n{}", "=".repeat(80));
276        println!("WINNERS");
277        println!("{}", "=".repeat(80));
278        println!("Best Quality:  {}", report.best_by_quality);
279        println!("Best Reward:   {}", report.best_by_reward);
280        println!("Fastest:       {}", report.best_by_time);
281        println!("{}", "=".repeat(80));
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-algorithm, one-run report by hand and checks its contents.
    #[test]
    fn test_comparison_report_creation() {
        let run_result = RunResult {
            run_number: 0,
            final_quality: 0.8,
            final_reward: 10.0,
            avg_quality_last_100: 0.75,
            avg_reward_last_100: 9.5,
            training_time_seconds: 100.0,
        };

        let algo_result = AlgorithmResult {
            algorithm: AlgorithmType::DuelingDQN,
            runs: vec![run_result],
            avg_quality: 0.75,
            std_quality: 0.05,
            avg_reward: 9.5,
            std_reward: 0.5,
            avg_training_time: 100.0,
        };

        assert_eq!(algo_result.runs.len(), 1);
        assert!((algo_result.avg_quality - 0.75).abs() < 0.01);

        let report = ComparisonReport {
            algorithms: vec![algo_result],
            best_by_quality: AlgorithmType::DuelingDQN.to_string(),
            best_by_reward: AlgorithmType::DuelingDQN.to_string(),
            best_by_time: AlgorithmType::DuelingDQN.to_string(),
            config: ComparisonConfig {
                episodes: 100,
                runs: 1,
                dataset_size: 10,
            },
        };

        assert_eq!(report.algorithms.len(), 1);
        assert_eq!(report.config.runs, report.algorithms[0].runs.len());
        assert_eq!(report.best_by_quality, report.best_by_time);
    }

    /// `ComparisonConfig` is embedded in the JSON reports, so it must survive
    /// a serialise/deserialise round trip unchanged.
    #[test]
    fn test_comparison_config_json_round_trip() {
        let config = ComparisonConfig {
            episodes: 500,
            runs: 3,
            dataset_size: 42,
        };

        let json = serde_json::to_string(&config).expect("serialize ComparisonConfig");
        let back: ComparisonConfig =
            serde_json::from_str(&json).expect("deserialize ComparisonConfig");

        assert_eq!(back.episodes, 500);
        assert_eq!(back.runs, 3);
        assert_eq!(back.dataset_size, 42);
    }
}