content_extractor_rl/evaluation/algorithm_comparison.rs

use crate::{
    Config, Result, agents::AlgorithmType,
    training::train_standard,
};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use tracing::info;

/// Aggregated results for one algorithm across all of its training runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlgorithmResult {
    pub algorithm: AlgorithmType,
    pub runs: Vec<RunResult>,
    pub avg_quality: f64,
    pub std_quality: f64,
    pub avg_reward: f64,
    pub std_reward: f64,
    pub avg_training_time: f64,
}

/// Metrics recorded for a single training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunResult {
    pub run_number: usize,
    pub final_quality: f32,
    pub final_reward: f32,
    pub avg_quality_last_100: f32,
    pub avg_reward_last_100: f32,
    pub training_time_seconds: f64,
}

/// Full comparison across all evaluated algorithms, plus the winner per metric.
#[derive(Debug, Serialize, Deserialize)]
pub struct ComparisonReport {
    pub algorithms: Vec<AlgorithmResult>,
    pub best_by_quality: String,
    pub best_by_reward: String,
    pub best_by_time: String,
    pub config: ComparisonConfig,
}

/// Settings the comparison was run with.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonConfig {
    pub episodes: usize,
    pub runs: usize,
    pub dataset_size: usize,
}

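/// Runs each candidate algorithm for `runs` independent training runs and
/// aggregates the results into a [`ComparisonReport`].
///
/// A minimal usage sketch (marked `ignore` because the import paths,
/// `Config::default()`, and the tiny dataset below are assumptions about the
/// rest of this crate):
///
/// ```ignore
/// use content_extractor_rl::{Config, agents::AlgorithmType};
/// use content_extractor_rl::evaluation::algorithm_comparison::AlgorithmComparator;
///
/// let comparator = AlgorithmComparator::new(Config::default(), "comparisons".into())?;
/// let report = comparator.compare_algorithms(
///     vec![AlgorithmType::DuelingDQN],
///     vec![("<html>...</html>".to_string(), "expected text".to_string())],
///     500, // episodes per run
///     3,   // runs per algorithm
/// )?;
/// println!("best by quality: {}", report.best_by_quality);
/// ```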
pub struct AlgorithmComparator {
    config: Config,
    output_dir: PathBuf,
}

impl AlgorithmComparator {
    /// Creates the output directory (if needed) and returns the comparator.
    pub fn new(config: Config, output_dir: PathBuf) -> Result<Self> {
        std::fs::create_dir_all(&output_dir)?;
        Ok(Self { config, output_dir })
    }

    /// Trains every algorithm in `algorithms`, then saves and prints a comparison report.
    pub fn compare_algorithms(
        &self,
        algorithms: Vec<AlgorithmType>,
        html_samples: Vec<(String, String)>,
        episodes: usize,
        runs: usize,
    ) -> Result<ComparisonReport> {
        info!("Starting algorithm comparison");
        info!("Algorithms: {:?}", algorithms);
        info!("Episodes: {}, Runs per algorithm: {}", episodes, runs);

        let mut results = Vec::new();

        for algorithm in algorithms {
            info!("Evaluating algorithm: {}", algorithm);
            let algo_result = self.evaluate_algorithm(
                algorithm,
                html_samples.clone(),
                episodes,
                runs,
            )?;
            results.push(algo_result);
        }

        // Pick the winner per metric; total_cmp avoids a panic if a mean is NaN.
        let best_by_quality = results.iter()
            .max_by(|a, b| a.avg_quality.total_cmp(&b.avg_quality))
            .map(|r| r.algorithm.to_string())
            .unwrap_or_else(|| "None".to_string());

        let best_by_reward = results.iter()
            .max_by(|a, b| a.avg_reward.total_cmp(&b.avg_reward))
            .map(|r| r.algorithm.to_string())
            .unwrap_or_else(|| "None".to_string());

        let best_by_time = results.iter()
            .min_by(|a, b| a.avg_training_time.total_cmp(&b.avg_training_time))
            .map(|r| r.algorithm.to_string())
            .unwrap_or_else(|| "None".to_string());

        let report = ComparisonReport {
            algorithms: results,
            best_by_quality,
            best_by_reward,
            best_by_time,
            config: ComparisonConfig {
                episodes,
                runs,
                dataset_size: html_samples.len(),
            },
        };

        self.save_report(&report)?;
        self.print_summary(&report);

        Ok(report)
    }

    fn evaluate_algorithm(
        &self,
        algorithm: AlgorithmType,
        html_samples: Vec<(String, String)>,
        episodes: usize,
        runs: usize,
    ) -> Result<AlgorithmResult> {
        let mut run_results = Vec::new();

        for run in 0..runs {
            info!("Algorithm: {}, Run: {}/{}", algorithm, run + 1, runs);

            let start_time = std::time::Instant::now();

            // Train a fresh agent with this run's algorithm and episode budget.
            let mut run_config = self.config.clone();
            run_config.algorithm = algorithm;
            run_config.num_episodes = episodes;

            let (_agent, metrics) = train_standard(&run_config, html_samples.clone())?;

            let training_time = start_time.elapsed().as_secs_f64();

            let final_quality = metrics.episode_qualities.last().copied().unwrap_or(0.0);
            let final_reward = metrics.episode_rewards.last().copied().unwrap_or(0.0);

            // Mean over the last 100 episodes, or over all episodes if there are fewer.
            let avg_quality_last_100 = if metrics.episode_qualities.len() >= 100 {
                metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
                    .iter()
                    .sum::<f32>() / 100.0
            } else if !metrics.episode_qualities.is_empty() {
                metrics.episode_qualities.iter().sum::<f32>() / metrics.episode_qualities.len() as f32
            } else {
                0.0
            };

            let avg_reward_last_100 = if metrics.episode_rewards.len() >= 100 {
                metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
                    .iter()
                    .sum::<f32>() / 100.0
            } else if !metrics.episode_rewards.is_empty() {
                metrics.episode_rewards.iter().sum::<f32>() / metrics.episode_rewards.len() as f32
            } else {
                0.0
            };

            let run_result = RunResult {
                run_number: run,
                final_quality,
                final_reward,
                avg_quality_last_100,
                avg_reward_last_100,
                training_time_seconds: training_time,
            };

            run_results.push(run_result);

            info!("Run {} complete: quality={:.4}, reward={:.4}, time={:.2}s",
                run + 1, avg_quality_last_100, avg_reward_last_100, training_time);
        }

        // Mean and population standard deviation across runs.
        let avg_quality = run_results.iter()
            .map(|r| r.avg_quality_last_100 as f64)
            .sum::<f64>() / runs as f64;

        let std_quality = {
            let variance = run_results.iter()
                .map(|r| {
                    let diff = r.avg_quality_last_100 as f64 - avg_quality;
                    diff * diff
                })
                .sum::<f64>() / runs as f64;
            variance.sqrt()
        };

        let avg_reward = run_results.iter()
            .map(|r| r.avg_reward_last_100 as f64)
            .sum::<f64>() / runs as f64;

        let std_reward = {
            let variance = run_results.iter()
                .map(|r| {
                    let diff = r.avg_reward_last_100 as f64 - avg_reward;
                    diff * diff
                })
                .sum::<f64>() / runs as f64;
            variance.sqrt()
        };

        let avg_training_time = run_results.iter()
            .map(|r| r.training_time_seconds)
            .sum::<f64>() / runs as f64;

        Ok(AlgorithmResult {
            algorithm,
            runs: run_results,
            avg_quality,
            std_quality,
            avg_reward,
            std_reward,
            avg_training_time,
        })
    }

    /// Serializes the report to a timestamped JSON file in the output directory.
    fn save_report(&self, report: &ComparisonReport) -> Result<()> {
        let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
        let path = self.output_dir.join(format!("comparison_report_{}.json", timestamp));

        let json = serde_json::to_string_pretty(report)?;
        std::fs::write(&path, json)?;

        info!("Comparison report saved to: {}", path.display());
        Ok(())
    }

    /// Prints a human-readable summary of the comparison to stdout.
    fn print_summary(&self, report: &ComparisonReport) {
        println!("\n{}", "=".repeat(80));
        println!("ALGORITHM COMPARISON RESULTS");
        println!("{}", "=".repeat(80));
        println!("Episodes: {}, Runs per algorithm: {}",
            report.config.episodes, report.config.runs);
        println!("Dataset size: {}", report.config.dataset_size);
        println!("{}", "=".repeat(80));

        for result in &report.algorithms {
            println!("\nAlgorithm: {}", result.algorithm);
            println!("  Average Quality: {:.4} ± {:.4}", result.avg_quality, result.std_quality);
            println!("  Average Reward: {:.4} ± {:.4}", result.avg_reward, result.std_reward);
            println!("  Average Time: {:.2}s", result.avg_training_time);
            println!("  Individual runs:");
            for run in &result.runs {
                println!("    Run {}: quality={:.4}, reward={:.4}, time={:.2}s",
                    run.run_number + 1,
                    run.avg_quality_last_100,
                    run.avg_reward_last_100,
                    run.training_time_seconds);
            }
        }

        println!("\n{}", "=".repeat(80));
        println!("WINNERS");
        println!("{}", "=".repeat(80));
        println!("Best Quality: {}", report.best_by_quality);
        println!("Best Reward: {}", report.best_by_reward);
        println!("Fastest: {}", report.best_by_time);
        println!("{}", "=".repeat(80));
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_comparison_report_creation() {
        let run_result = RunResult {
            run_number: 0,
            final_quality: 0.8,
            final_reward: 10.0,
            avg_quality_last_100: 0.75,
            avg_reward_last_100: 9.5,
            training_time_seconds: 100.0,
        };

        let algo_result = AlgorithmResult {
            algorithm: AlgorithmType::DuelingDQN,
            runs: vec![run_result],
            avg_quality: 0.75,
            std_quality: 0.05,
            avg_reward: 9.5,
            std_reward: 0.5,
            avg_training_time: 100.0,
        };

        assert_eq!(algo_result.runs.len(), 1);
        assert!((algo_result.avg_quality - 0.75).abs() < 0.01);
    }
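
    // Construction sketch: `new` should create the output directory. This test
    // assumes `Config` implements `Default`; swap in the crate's real config
    // constructor if it does not.
    #[test]
    fn test_comparator_creates_output_dir() {
        let tmp = TempDir::new().unwrap();
        let out = tmp.path().join("comparisons");
        let comparator = AlgorithmComparator::new(Config::default(), out.clone());
        assert!(comparator.is_ok());
        assert!(out.exists());
    }

    // Round-trip sketch: the report types derive Serialize/Deserialize, so a
    // saved report should load back unchanged. The field values here are
    // arbitrary test data, not real training results.
    #[test]
    fn test_comparison_report_roundtrip() {
        let report = ComparisonReport {
            algorithms: vec![],
            best_by_quality: "dueling_dqn".to_string(),
            best_by_reward: "dueling_dqn".to_string(),
            best_by_time: "dqn".to_string(),
            config: ComparisonConfig {
                episodes: 500,
                runs: 3,
                dataset_size: 10,
            },
        };

        let json = serde_json::to_string_pretty(&report).unwrap();
        let restored: ComparisonReport = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.best_by_quality, "dueling_dqn");
        assert_eq!(restored.config.episodes, 500);
        assert_eq!(restored.config.dataset_size, 10);
    }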
}