content_extractor_rl/evaluation/
algorithm_comparison.rs1use crate::{
6 Config, Result, agents::AlgorithmType,
7 training::{train_standard},
8};
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11use tracing::info;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AlgorithmResult {
16 pub algorithm: AlgorithmType,
17 pub runs: Vec<RunResult>,
18 pub avg_quality: f64,
19 pub std_quality: f64,
20 pub avg_reward: f64,
21 pub std_reward: f64,
22 pub avg_training_time: f64,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct RunResult {
28 pub run_number: usize,
29 pub final_quality: f32,
30 pub final_reward: f32,
31 pub avg_quality_last_100: f32,
32 pub avg_reward_last_100: f32,
33 pub training_time_seconds: f64,
34}
35
36#[derive(Debug, Serialize, Deserialize)]
38pub struct ComparisonReport {
39 pub algorithms: Vec<AlgorithmResult>,
40 pub best_by_quality: String,
41 pub best_by_reward: String,
42 pub best_by_time: String,
43 pub config: ComparisonConfig,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ComparisonConfig {
49 pub episodes: usize,
50 pub runs: usize,
51 pub dataset_size: usize,
52}
53
54pub struct AlgorithmComparator {
56 config: Config,
57 output_dir: PathBuf,
58}
59
60impl AlgorithmComparator {
61 pub fn new(config: Config, output_dir: PathBuf) -> Result<Self> {
62 std::fs::create_dir_all(&output_dir)?;
63 Ok(Self { config, output_dir })
64 }
65
66 pub fn compare_algorithms(
68 &self,
69 algorithms: Vec<AlgorithmType>,
70 html_samples: Vec<(String, String)>,
71 episodes: usize,
72 runs: usize,
73 ) -> Result<ComparisonReport> {
74 info!("Starting algorithm comparison");
75 info!("Algorithms: {:?}", algorithms);
76 info!("Episodes: {}, Runs per algorithm: {}", episodes, runs);
77
78 let mut results = Vec::new();
79
80 for algorithm in algorithms {
81 info!("Evaluating algorithm: {}", algorithm);
82 let algo_result = self.evaluate_algorithm(
83 algorithm,
84 html_samples.clone(),
85 episodes,
86 runs,
87 )?;
88 results.push(algo_result);
89 }
90
91 let best_by_quality = results.iter()
93 .max_by(|a, b| a.avg_quality.partial_cmp(&b.avg_quality).unwrap())
94 .map(|r| r.algorithm.to_string())
95 .unwrap_or_else(|| "None".to_string());
96
97 let best_by_reward = results.iter()
98 .max_by(|a, b| a.avg_reward.partial_cmp(&b.avg_reward).unwrap())
99 .map(|r| r.algorithm.to_string())
100 .unwrap_or_else(|| "None".to_string());
101
102 let best_by_time = results.iter()
103 .min_by(|a, b| a.avg_training_time.partial_cmp(&b.avg_training_time).unwrap())
104 .map(|r| r.algorithm.to_string())
105 .unwrap_or_else(|| "None".to_string());
106
107 let report = ComparisonReport {
108 algorithms: results,
109 best_by_quality: best_by_quality.clone(),
110 best_by_reward: best_by_reward.clone(),
111 best_by_time: best_by_time.clone(),
112 config: ComparisonConfig {
113 episodes,
114 runs,
115 dataset_size: html_samples.len(),
116 },
117 };
118
119 self.save_report(&report)?;
121
122 self.print_summary(&report);
124
125 Ok(report)
126 }
127
128 fn evaluate_algorithm(
130 &self,
131 algorithm: AlgorithmType,
132 html_samples: Vec<(String, String)>,
133 episodes: usize,
134 runs: usize,
135 ) -> Result<AlgorithmResult> {
136 let mut run_results = Vec::new();
137
138 for run in 0..runs {
139 info!("Algorithm: {}, Run: {}/{}", algorithm, run + 1, runs);
140
141 let start_time = std::time::Instant::now();
142
143 let mut run_config = self.config.clone();
145 run_config.algorithm = algorithm;
146 run_config.num_episodes = episodes;
147
148 let run_samples: Vec<crate::TrainingSample> =
150 html_samples.clone().into_iter().map(Into::into).collect();
151 let (_agent, metrics) = train_standard(&run_config, run_samples)?;
152
153 let training_time = start_time.elapsed().as_secs_f64();
154
155 let final_quality = metrics.episode_qualities.last().copied().unwrap_or(0.0);
157 let final_reward = metrics.episode_rewards.last().copied().unwrap_or(0.0);
158
159 let avg_quality_last_100 = if metrics.episode_qualities.len() >= 100 {
160 metrics.episode_qualities[metrics.episode_qualities.len() - 100..]
161 .iter()
162 .sum::<f32>() / 100.0
163 } else if !metrics.episode_qualities.is_empty() {
164 metrics.episode_qualities.iter().sum::<f32>() / metrics.episode_qualities.len() as f32
165 } else {
166 0.0
167 };
168
169 let avg_reward_last_100 = if metrics.episode_rewards.len() >= 100 {
170 metrics.episode_rewards[metrics.episode_rewards.len() - 100..]
171 .iter()
172 .sum::<f32>() / 100.0
173 } else if !metrics.episode_rewards.is_empty() {
174 metrics.episode_rewards.iter().sum::<f32>() / metrics.episode_rewards.len() as f32
175 } else {
176 0.0
177 };
178
179 let run_result = RunResult {
180 run_number: run,
181 final_quality,
182 final_reward,
183 avg_quality_last_100,
184 avg_reward_last_100,
185 training_time_seconds: training_time,
186 };
187
188 run_results.push(run_result);
189
190 info!("Run {} complete: quality={:.4}, reward={:.4}, time={:.2}s",
191 run + 1, avg_quality_last_100, avg_reward_last_100, training_time);
192 }
193
194 let avg_quality = run_results.iter()
196 .map(|r| r.avg_quality_last_100 as f64)
197 .sum::<f64>() / runs as f64;
198
199 let std_quality = {
200 let variance = run_results.iter()
201 .map(|r| {
202 let diff = r.avg_quality_last_100 as f64 - avg_quality;
203 diff * diff
204 })
205 .sum::<f64>() / runs as f64;
206 variance.sqrt()
207 };
208
209 let avg_reward = run_results.iter()
210 .map(|r| r.avg_reward_last_100 as f64)
211 .sum::<f64>() / runs as f64;
212
213 let std_reward = {
214 let variance = run_results.iter()
215 .map(|r| {
216 let diff = r.avg_reward_last_100 as f64 - avg_reward;
217 diff * diff
218 })
219 .sum::<f64>() / runs as f64;
220 variance.sqrt()
221 };
222
223 let avg_training_time = run_results.iter()
224 .map(|r| r.training_time_seconds)
225 .sum::<f64>() / runs as f64;
226
227 Ok(AlgorithmResult {
228 algorithm,
229 runs: run_results,
230 avg_quality,
231 std_quality,
232 avg_reward,
233 std_reward,
234 avg_training_time,
235 })
236 }
237
238 fn save_report(&self, report: &ComparisonReport) -> Result<()> {
240 let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
241 let path = self.output_dir.join(format!("comparison_report_{}.json", timestamp));
242
243 let json = serde_json::to_string_pretty(report)?;
244 std::fs::write(&path, json)?;
245
246 info!("Comparison report saved to: {}", path.display());
247 Ok(())
248 }
249
250 fn print_summary(&self, report: &ComparisonReport) {
252 println!("\n{}", "=".repeat(80));
253 println!("ALGORITHM COMPARISON RESULTS");
254 println!("{}", "=".repeat(80));
255 println!("Episodes: {}, Runs per algorithm: {}",
256 report.config.episodes, report.config.runs);
257 println!("Dataset size: {}", report.config.dataset_size);
258 println!("{}", "=".repeat(80));
259
260 for result in &report.algorithms {
261 println!("\nAlgorithm: {}", result.algorithm);
262 println!(" Average Quality: {:.4} ± {:.4}", result.avg_quality, result.std_quality);
263 println!(" Average Reward: {:.4} ± {:.4}", result.avg_reward, result.std_reward);
264 println!(" Average Time: {:.2}s", result.avg_training_time);
265 println!(" Individual runs:");
266 for run in &result.runs {
267 println!(" Run {}: quality={:.4}, reward={:.4}, time={:.2}s",
268 run.run_number + 1,
269 run.avg_quality_last_100,
270 run.avg_reward_last_100,
271 run.training_time_seconds);
272 }
273 }
274
275 println!("\n{}", "=".repeat(80));
276 println!("WINNERS");
277 println!("{}", "=".repeat(80));
278 println!("Best Quality: {}", report.best_by_quality);
279 println!("Best Reward: {}", report.best_by_reward);
280 println!("Fastest: {}", report.best_by_time);
281 println!("{}", "=".repeat(80));
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `TempDir` is not referenced by the test below.
    use tempfile::TempDir;

    /// Builds a single-run `AlgorithmResult` by hand and checks that the
    /// values it was given can be read back.
    #[test]
    fn test_comparison_report_creation() {
        let recorded_runs = vec![RunResult {
            run_number: 0,
            final_quality: 0.8,
            final_reward: 10.0,
            avg_quality_last_100: 0.75,
            avg_reward_last_100: 9.5,
            training_time_seconds: 100.0,
        }];

        let summary = AlgorithmResult {
            algorithm: AlgorithmType::DuelingDQN,
            runs: recorded_runs,
            avg_quality: 0.75,
            std_quality: 0.05,
            avg_reward: 9.5,
            std_reward: 0.5,
            avg_training_time: 100.0,
        };

        assert_eq!(summary.runs.len(), 1);

        let quality_error = (summary.avg_quality - 0.75).abs();
        assert!(quality_error < 0.01);
    }
}