1use crate::{Result, training::TrainingMetrics};
7use plotters::prelude::*;
8use std::path::Path;
9use tracing::info;
10
/// Canvas settings for the images rendered by `TrainingPlotter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PlotConfig {
    /// Output image width in pixels.
    pub width: u32,
    /// Output image height in pixels.
    pub height: u32,
    /// Intended output resolution in dots per inch.
    ///
    /// NOTE(review): nothing in this module reads this field yet; the bitmap is rendered at
    /// exactly `width` x `height` pixels regardless of its value.
    pub dpi: u32,
}

impl Default for PlotConfig {
    /// A 1600x1200 pixel canvas with `dpi` set to 150.
    fn default() -> Self {
        Self {
            width: 1600,
            height: 1200,
            dpi: 150,
        }
    }
}
27
28pub struct TrainingPlotter {
30 config: PlotConfig,
31}
32
33impl TrainingPlotter {
34 pub fn new() -> Self {
36 Self {
37 config: PlotConfig::default(),
38 }
39 }
40
41 pub fn with_config(config: PlotConfig) -> Self {
43 Self { config }
44 }
45
46 pub fn plot_training_results(&self, metrics: &TrainingMetrics, output_path: &Path) -> Result<()> {
48 info!("Generating training plots to: {}", output_path.display());
49
50 let root = BitMapBackend::new(
51 output_path,
52 (self.config.width, self.config.height)
53 ).into_drawing_area();
54
55 root.fill(&WHITE)
56 .map_err(|e| crate::ExtractionError::ModelError(format!("Plot fill error: {}", e)))?;
57
58 let areas = root.split_evenly((2, 2));
60
61 self.plot_rewards(&areas[0], &metrics.episode_rewards)?;
63
64 self.plot_quality(&areas[1], &metrics.episode_qualities)?;
66
67 self.plot_reward_distribution(&areas[2], &metrics.episode_rewards)?;
69
70 self.plot_quality_distribution(&areas[3], &metrics.episode_qualities)?;
72
73 root.present()
74 .map_err(|e| crate::ExtractionError::ModelError(format!("Plot present error: {}", e)))?;
75
76 info!("Training plots saved successfully");
77 Ok(())
78 }
79
80 fn plot_rewards<DB: DrawingBackend>(
82 &self,
83 area: &DrawingArea<DB, plotters::coord::Shift>,
84 rewards: &[f32],
85 ) -> Result<()>
86 where
87 DB::ErrorType: 'static,
88 {
89 if rewards.is_empty() {
90 return Ok(());
91 }
92
93 let max_episodes = rewards.len();
94 let max_reward = rewards.iter().copied().fold(f32::NEG_INFINITY, f32::max);
95 let min_reward = rewards.iter().copied().fold(f32::INFINITY, f32::min);
96
97 let mut chart = ChartBuilder::on(area)
98 .caption("Episode Rewards", ("sans-serif", 30).into_font())
99 .margin(10)
100 .x_label_area_size(30)
101 .y_label_area_size(50)
102 .build_cartesian_2d(0..max_episodes, min_reward..max_reward)
103 .map_err(|e| crate::ExtractionError::ModelError(format!("Chart build error: {}", e)))?;
104
105 chart.configure_mesh()
106 .x_desc("Episode")
107 .y_desc("Reward")
108 .draw()
109 .map_err(|e| crate::ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
110
111 chart.draw_series(LineSeries::new(
113 rewards.iter().enumerate().map(|(i, &r)| (i, r)),
114 &BLUE.mix(0.5),
115 ))
116 .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?
117 .label("Raw")
118 .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], BLUE));
119
120 if rewards.len() > 100 {
122 let window = rewards.len().min(100);
123 let moving_avg = self.calculate_moving_average(rewards, window);
124
125 chart.draw_series(LineSeries::new(
126 moving_avg.into_iter()
127 .enumerate()
128 .map(|(i, avg)| (i + window - 1, avg)),
129 &RED,
130 ))
131 .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?
132 .label(format!("MA({})", window))
133 .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], RED));
134 }
135
136 chart.configure_series_labels()
137 .background_style(WHITE.mix(0.8))
138 .border_style(BLACK)
139 .draw()
140 .map_err(|e| crate::ExtractionError::ModelError(format!("Legend error: {}", e)))?;
141
142 Ok(())
143 }
144
145 fn plot_quality<DB: DrawingBackend>(
147 &self,
148 area: &DrawingArea<DB, plotters::coord::Shift>,
149 qualities: &[f32],
150 ) -> Result<()>
151 where
152 DB::ErrorType: 'static,
153 {
154 if qualities.is_empty() {
155 return Ok(());
156 }
157
158 let max_episodes = qualities.len();
159 let max_quality = qualities.iter().copied().fold(f32::NEG_INFINITY, f32::max);
160
161 let mut chart = ChartBuilder::on(area)
162 .caption("Episode Quality", ("sans-serif", 30).into_font())
163 .margin(10)
164 .x_label_area_size(30)
165 .y_label_area_size(50)
166 .build_cartesian_2d(0..max_episodes, 0.0..max_quality.max(1.0))
167 .map_err(|e| crate::ExtractionError::ModelError(format!("Chart build error: {}", e)))?;
168
169 chart.configure_mesh()
170 .x_desc("Episode")
171 .y_desc("Quality Score")
172 .draw()
173 .map_err(|e| crate::ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
174
175 chart.draw_series(LineSeries::new(
177 qualities.iter().enumerate().map(|(i, &q)| (i, q)),
178 &GREEN.mix(0.5),
179 ))
180 .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?
181 .label("Raw")
182 .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], GREEN));
183
184 if qualities.len() > 100 {
186 let window = qualities.len().min(100);
187 let moving_avg = self.calculate_moving_average(qualities, window);
188
189 chart.draw_series(LineSeries::new(
190 moving_avg.into_iter()
191 .enumerate()
192 .map(|(i, avg)| (i + window - 1, avg)),
193 &RED,
194 ))
195 .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?
196 .label(format!("MA({})", window))
197 .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], RED));
198 }
199
200 chart.configure_series_labels()
201 .background_style(WHITE.mix(0.8))
202 .border_style(BLACK)
203 .draw()
204 .map_err(|e| crate::ExtractionError::ModelError(format!("Legend error: {}", e)))?;
205
206 Ok(())
207 }
208
209 fn plot_reward_distribution<DB: DrawingBackend>(
211 &self,
212 area: &DrawingArea<DB, plotters::coord::Shift>,
213 rewards: &[f32],
214 ) -> Result<()>
215 where
216 DB::ErrorType: 'static,
217 {
218 if rewards.is_empty() {
219 return Ok(());
220 }
221
222 let max_reward = rewards.iter().copied().fold(f32::NEG_INFINITY, f32::max);
223 let min_reward = rewards.iter().copied().fold(f32::INFINITY, f32::min);
224
225 let n_bins = 50;
227 let bin_width = (max_reward - min_reward) / n_bins as f32;
228 let mut histogram = vec![0usize; n_bins];
229
230 for &reward in rewards {
231 let bin = ((reward - min_reward) / bin_width).floor() as usize;
232 let bin = bin.min(n_bins - 1);
233 histogram[bin] += 1;
234 }
235
236 let max_count = *histogram.iter().max().unwrap_or(&1);
237
238 let mut chart = ChartBuilder::on(area)
239 .caption("Reward Distribution", ("sans-serif", 30).into_font())
240 .margin(10)
241 .x_label_area_size(30)
242 .y_label_area_size(50)
243 .build_cartesian_2d(min_reward..max_reward, 0..max_count)
244 .map_err(|e| crate::ExtractionError::ModelError(format!("Chart build error: {}", e)))?;
245
246 chart.configure_mesh()
247 .x_desc("Reward")
248 .y_desc("Frequency")
249 .draw()
250 .map_err(|e| crate::ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
251
252 chart.draw_series(
254 histogram.iter().enumerate().map(|(i, &count)| {
255 let x0 = min_reward + i as f32 * bin_width;
256 let x1 = x0 + bin_width;
257 Rectangle::new([(x0, 0), (x1, count)], BLUE.mix(0.7).filled())
258 })
259 )
260 .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?;
261
262 let mean = rewards.iter().sum::<f32>() / rewards.len() as f32;
264 chart.draw_series(LineSeries::new(
265 vec![(mean, 0), (mean, max_count)],
266 RED.stroke_width(2),
267 ))
268 .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?
269 .label(format!("Mean: {:.3}", mean))
270 .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], RED));
271
272 chart.configure_series_labels()
273 .background_style(WHITE.mix(0.8))
274 .border_style(BLACK)
275 .draw()
276 .map_err(|e| crate::ExtractionError::ModelError(format!("Legend error: {}", e)))?;
277
278 Ok(())
279 }
280
281 fn plot_quality_distribution<DB: DrawingBackend>(
283 &self,
284 area: &DrawingArea<DB, plotters::coord::Shift>,
285 qualities: &[f32],
286 ) -> Result<()>
287 where
288 DB::ErrorType: 'static,
289 {
290 if qualities.is_empty() {
291 return Ok(());
292 }
293
294 let max_quality = qualities.iter().copied().fold(f32::NEG_INFINITY, f32::max);
295 let min_quality = qualities.iter().copied().fold(f32::INFINITY, f32::min);
296
297 let n_bins = 50;
299 let bin_width = (max_quality - min_quality).max(0.01) / n_bins as f32;
300 let mut histogram = vec![0usize; n_bins];
301
302 for &quality in qualities {
303 let bin = ((quality - min_quality) / bin_width).floor() as usize;
304 let bin = bin.min(n_bins - 1);
305 histogram[bin] += 1;
306 }
307
308 let max_count = *histogram.iter().max().unwrap_or(&1);
309
310 let mut chart = ChartBuilder::on(area)
311 .caption("Quality Distribution", ("sans-serif", 30).into_font())
312 .margin(10)
313 .x_label_area_size(30)
314 .y_label_area_size(50)
315 .build_cartesian_2d(min_quality..max_quality.max(1.0), 0..max_count)
316 .map_err(|e| crate::ExtractionError::ModelError(format!("Chart build error: {}", e)))?;
317
318 chart.configure_mesh()
319 .x_desc("Quality Score")
320 .y_desc("Frequency")
321 .draw()
322 .map_err(|e| crate::ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
323
324 chart.draw_series(
326 histogram.iter().enumerate().map(|(i, &count)| {
327 let x0 = min_quality + i as f32 * bin_width;
328 let x1 = x0 + bin_width;
329 Rectangle::new([(x0, 0), (x1, count)], GREEN.mix(0.7).filled())
330 })
331 )
332 .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?;
333
334 let mean = qualities.iter().sum::<f32>() / qualities.len() as f32;
336 chart.draw_series(LineSeries::new(
337 vec![(mean, 0), (mean, max_count)],
338 RED.stroke_width(2),
339 ))
340 .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?
341 .label(format!("Mean: {:.3}", mean))
342 .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], RED));
343
344 chart.configure_series_labels()
345 .background_style(WHITE.mix(0.8))
346 .border_style(BLACK)
347 .draw()
348 .map_err(|e| crate::ExtractionError::ModelError(format!("Legend error: {}", e)))?;
349
350 Ok(())
351 }
352
353 fn calculate_moving_average(&self, data: &[f32], window: usize) -> Vec<f32> {
355 let mut result = Vec::with_capacity(data.len() - window + 1);
356
357 for i in window - 1..data.len() {
358 let sum: f32 = data[i - window + 1..=i].iter().sum();
359 result.push(sum / window as f32);
360 }
361
362 result
363 }
364
365 pub fn plot_intermediate(&self, metrics: &TrainingMetrics, output_path: &Path, episode: usize) -> Result<()> {
367 let timestamped_path = output_path.parent().unwrap().join(
368 format!("training_plot_ep{}.png", episode)
369 );
370
371 self.plot_training_results(metrics, ×tamped_path)
372 }
373}
374
375impl Default for TrainingPlotter {
376 fn default() -> Self {
377 Self::new()
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Renders a 100-episode synthetic history and checks that the image file is created.
    #[test]
    fn test_plot_generation() {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("test_plot.png");

        // Rewards ramp linearly from -0.5 to 0.49; qualities from 0.0 to 0.99.
        let episode_rewards = (0..100).map(|i| i as f32 * 0.01 - 0.5).collect();
        let episode_qualities = (0..100).map(|i| i as f32 * 0.01).collect();

        let metrics = TrainingMetrics {
            episode_rewards,
            episode_qualities,
            episode_losses: vec![],
            best_avg_quality: 0.9,
        };

        TrainingPlotter::new()
            .plot_training_results(&metrics, &target)
            .unwrap();

        assert!(target.exists());
    }
}