Skip to main content

content_extractor_rl/evaluation/
algorithm_comparison.rs

1// ============================================================================
2// FILE: crates/content-extractor-rl/src/evaluation/algorithm_comparison.rs
3// ============================================================================
4
5use crate::{
6    Config, Result, agents::AlgorithmType,
7    training::{train_standard},
8};
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11use tracing::info;
12
13/// Results for a single algorithm
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AlgorithmResult {
16    pub algorithm: AlgorithmType,
17    pub runs: Vec<RunResult>,
18    pub avg_quality: f64,
19    pub std_quality: f64,
20    pub avg_reward: f64,
21    pub std_reward: f64,
22    pub avg_training_time: f64,
23}
24
25/// Results for a single run
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct RunResult {
28    pub run_number: usize,
29    pub final_quality: f32,
30    pub final_reward: f32,
31    pub avg_quality_last_100: f32,
32    pub avg_reward_last_100: f32,
33    pub training_time_seconds: f64,
34}
35
36/// Comparison report across all algorithms
37#[derive(Debug, Serialize, Deserialize)]
38pub struct ComparisonReport {
39    pub algorithms: Vec<AlgorithmResult>,
40    pub best_by_quality: String,
41    pub best_by_reward: String,
42    pub best_by_time: String,
43    pub config: ComparisonConfig,
44}
45
46/// Configuration for comparison
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ComparisonConfig {
49    pub episodes: usize,
50    pub runs: usize,
51    pub dataset_size: usize,
52}
53
54/// Algorithm comparator
55pub struct AlgorithmComparator {
56    config: Config,
57    output_dir: PathBuf,
58}
59
60impl AlgorithmComparator {
61    pub fn new(config: Config, output_dir: PathBuf) -> Result<Self> {
62        std::fs::create_dir_all(&output_dir)?;
63        Ok(Self { config, output_dir })
64    }
65
66    /// Compare multiple algorithms
67    pub fn compare_algorithms(
68        &self,
69        algorithms: Vec<AlgorithmType>,
70        html_samples: Vec<(String, String)>,
71        episodes: usize,
72        runs: usize,
73    ) -> Result<ComparisonReport> {
74        info!("Starting algorithm comparison");
75        info!("Algorithms: {:?}", algorithms);
76        info!("Episodes: {}, Runs per algorithm: {}", episodes, runs);
77
78        let mut results = Vec::new();
79
80        for algorithm in algorithms {
81            info!("Evaluating algorithm: {}", algorithm);
82            let algo_result = self.evaluate_algorithm(
83                algorithm,
84                html_samples.clone(),
85                episodes,
86                runs,
87            )?;
88            results.push(algo_result);
89        }
90
91        // Find best algorithms
92        let best_by_quality = results.iter()
93            .max_by(|a, b| a.avg_quality.partial_cmp(&b.avg_quality).unwrap())
94            .map(|r| r.algorithm.to_string())
95            .unwrap_or_else(|| "None".to_string());
96
97        let best_by_reward = results.iter()
98            .max_by(|a, b| a.avg_reward.partial_cmp(&b.avg_reward).unwrap())
99            .map(|r| r.algorithm.to_string())
100            .unwrap_or_else(|| "None".to_string());
101
102        let best_by_time = results.iter()
103            .min_by(|a, b| a.avg_training_time.partial_cmp(&b.avg_training_time).unwrap())
104            .map(|r| r.algorithm.to_string())
105            .unwrap_or_else(|| "None".to_string());
106
107        let report = ComparisonReport {
108            algorithms: results,
109            best_by_quality: best_by_quality.clone(),
110            best_by_reward: best_by_reward.clone(),
111            best_by_time: best_by_time.clone(),
112            config: ComparisonConfig {
113                episodes,
114                runs,
115                dataset_size: html_samples.len(),
116            },
117        };
118
119        // Save report
120        self.save_report(&report)?;
121
122        // Print summary
123        self.print_summary(&report);
124
125        Ok(report)
126    }
127
128    /// Evaluate a single algorithm with multiple runs
129    fn evaluate_algorithm(
130        &self,
131        algorithm: AlgorithmType,
132        html_samples: Vec<(String, String)>,
133        episodes: usize,
134        runs: usize,
135    ) -> Result<AlgorithmResult> {
136        let mut run_results = Vec::new();
137
138        for run in 0..runs {
139            info!("Algorithm: {}, Run: {}/{}", algorithm, run + 1, runs);
140
141            let start_time = std::time::Instant::now();
142
143            // Create config for this run
144            let mut run_config = self.config.clone();
145            run_config.algorithm = algorithm;
146            run_config.num_episodes = episodes;
147
148            // Train
149            let (_agent, metrics) = train_standard(&run_config, html_samples.clone())?;
150
151            let training_time = start_time.elapsed().as_secs_f64();
152
153            // Calculate metrics
154            let final_quality = metrics.episode_qualities.last().copied().unwrap_or(0.0);
155            let final_reward = metrics.episode_rewards.last().copied().unwrap_or(0.0);
156
157            let avg_quality_last_100 = if metrics.episode_qualities.len() >= 100 {
158                metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
159                    .iter()
160                    .sum::<f32>() / 100.0
161            } else if !metrics.episode_qualities.is_empty() {
162                metrics.episode_qualities.iter().sum::<f32>() / metrics.episode_qualities.len() as f32
163            } else {
164                0.0
165            };
166
167            let avg_reward_last_100 = if metrics.episode_rewards.len() >= 100 {
168                metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
169                    .iter()
170                    .sum::<f32>() / 100.0
171            } else if !metrics.episode_rewards.is_empty() {
172                metrics.episode_rewards.iter().sum::<f32>() / metrics.episode_rewards.len() as f32
173            } else {
174                0.0
175            };
176
177            let run_result = RunResult {
178                run_number: run,
179                final_quality,
180                final_reward,
181                avg_quality_last_100,
182                avg_reward_last_100,
183                training_time_seconds: training_time,
184            };
185
186            run_results.push(run_result);
187
188            info!("Run {} complete: quality={:.4}, reward={:.4}, time={:.2}s",
189                  run + 1, avg_quality_last_100, avg_reward_last_100, training_time);
190        }
191
192        // Calculate statistics
193        let avg_quality = run_results.iter()
194            .map(|r| r.avg_quality_last_100 as f64)
195            .sum::<f64>() / runs as f64;
196
197        let std_quality = {
198            let variance = run_results.iter()
199                .map(|r| {
200                    let diff = r.avg_quality_last_100 as f64 - avg_quality;
201                    diff * diff
202                })
203                .sum::<f64>() / runs as f64;
204            variance.sqrt()
205        };
206
207        let avg_reward = run_results.iter()
208            .map(|r| r.avg_reward_last_100 as f64)
209            .sum::<f64>() / runs as f64;
210
211        let std_reward = {
212            let variance = run_results.iter()
213                .map(|r| {
214                    let diff = r.avg_reward_last_100 as f64 - avg_reward;
215                    diff * diff
216                })
217                .sum::<f64>() / runs as f64;
218            variance.sqrt()
219        };
220
221        let avg_training_time = run_results.iter()
222            .map(|r| r.training_time_seconds)
223            .sum::<f64>() / runs as f64;
224
225        Ok(AlgorithmResult {
226            algorithm,
227            runs: run_results,
228            avg_quality,
229            std_quality,
230            avg_reward,
231            std_reward,
232            avg_training_time,
233        })
234    }
235
236    /// Save comparison report
237    fn save_report(&self, report: &ComparisonReport) -> Result<()> {
238        let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
239        let path = self.output_dir.join(format!("comparison_report_{}.json", timestamp));
240
241        let json = serde_json::to_string_pretty(report)?;
242        std::fs::write(&path, json)?;
243
244        info!("Comparison report saved to: {}", path.display());
245        Ok(())
246    }
247
248    /// Print summary to console
249    fn print_summary(&self, report: &ComparisonReport) {
250        println!("\n{}", "=".repeat(80));
251        println!("ALGORITHM COMPARISON RESULTS");
252        println!("{}", "=".repeat(80));
253        println!("Episodes: {}, Runs per algorithm: {}",
254                 report.config.episodes, report.config.runs);
255        println!("Dataset size: {}", report.config.dataset_size);
256        println!("{}", "=".repeat(80));
257
258        for result in &report.algorithms {
259            println!("\nAlgorithm: {}", result.algorithm);
260            println!("  Average Quality:  {:.4} ± {:.4}", result.avg_quality, result.std_quality);
261            println!("  Average Reward:   {:.4} ± {:.4}", result.avg_reward, result.std_reward);
262            println!("  Average Time:     {:.2}s", result.avg_training_time);
263            println!("  Individual runs:");
264            for run in &result.runs {
265                println!("    Run {}: quality={:.4}, reward={:.4}, time={:.2}s",
266                         run.run_number + 1,
267                         run.avg_quality_last_100,
268                         run.avg_reward_last_100,
269                         run.training_time_seconds);
270            }
271        }
272
273        println!("\n{}", "=".repeat(80));
274        println!("WINNERS");
275        println!("{}", "=".repeat(80));
276        println!("Best Quality:  {}", report.best_by_quality);
277        println!("Best Reward:   {}", report.best_by_reward);
278        println!("Fastest:       {}", report.best_by_time);
279        println!("{}", "=".repeat(80));
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_comparison_report_creation() {
        let run_result = RunResult {
            run_number: 0,
            final_quality: 0.8,
            final_reward: 10.0,
            avg_quality_last_100: 0.75,
            avg_reward_last_100: 9.5,
            training_time_seconds: 100.0,
        };

        let algo_result = AlgorithmResult {
            algorithm: AlgorithmType::DuelingDQN,
            runs: vec![run_result],
            avg_quality: 0.75,
            std_quality: 0.05,
            avg_reward: 9.5,
            std_reward: 0.5,
            avg_training_time: 100.0,
        };

        assert_eq!(algo_result.runs.len(), 1);
        assert!((algo_result.avg_quality - 0.75).abs() < 0.01);
    }

    /// A `RunResult` must survive the JSON round trip used for saved reports.
    /// All values are exactly representable as binary floats, so exact
    /// comparison is safe.
    #[test]
    fn test_run_result_json_round_trip() {
        let run = RunResult {
            run_number: 2,
            final_quality: 0.875,
            final_reward: 12.5,
            avg_quality_last_100: 0.75,
            avg_reward_last_100: 9.5,
            training_time_seconds: 42.25,
        };

        let json = serde_json::to_string(&run).expect("RunResult should serialize");
        let back: RunResult = serde_json::from_str(&json).expect("RunResult should deserialize");

        assert_eq!(back.run_number, run.run_number);
        assert_eq!(back.final_quality, run.final_quality);
        assert_eq!(back.final_reward, run.final_reward);
        assert_eq!(back.avg_quality_last_100, run.avg_quality_last_100);
        assert_eq!(back.avg_reward_last_100, run.avg_reward_last_100);
        assert_eq!(back.training_time_seconds, run.training_time_seconds);
    }

    /// A full report must round-trip through the pretty JSON format that
    /// `save_report` writes.
    #[test]
    fn test_comparison_report_json_round_trip() {
        let winner = AlgorithmType::DuelingDQN.to_string();
        let report = ComparisonReport {
            algorithms: vec![AlgorithmResult {
                algorithm: AlgorithmType::DuelingDQN,
                runs: Vec::new(),
                avg_quality: 0.5,
                std_quality: 0.0,
                avg_reward: 1.0,
                std_reward: 0.0,
                avg_training_time: 3.0,
            }],
            best_by_quality: winner.clone(),
            best_by_reward: winner.clone(),
            best_by_time: winner.clone(),
            config: ComparisonConfig {
                episodes: 10,
                runs: 1,
                dataset_size: 4,
            },
        };

        let json = serde_json::to_string_pretty(&report).expect("report should serialize");
        let back: ComparisonReport =
            serde_json::from_str(&json).expect("report should deserialize");

        assert_eq!(back.algorithms.len(), 1);
        assert_eq!(back.algorithms[0].algorithm.to_string(), winner);
        assert!(back.algorithms[0].runs.is_empty());
        assert_eq!(back.best_by_quality, winner);
        assert_eq!(back.best_by_reward, winner);
        assert_eq!(back.best_by_time, winner);
        assert_eq!(back.config.episodes, 10);
        assert_eq!(back.config.runs, 1);
        assert_eq!(back.config.dataset_size, 4);
    }
}