Skip to main content

content_extractor_rl/
plotting.rs

//! Training visualization and plotting using the `plotters` library.
// ============================================================================
// FILE: crates/content-extractor-rl/src/plotting.rs
// ============================================================================
5
6use crate::{Result, training::TrainingMetrics};
7use plotters::prelude::*;
8use std::path::Path;
9use tracing::info;
10
/// Rendering parameters for the training plots.
///
/// `width` and `height` give the pixel size of the bitmap that
/// [`TrainingPlotter`] renders into. `dpi` is stored alongside them but is not
/// read by any plotting routine in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PlotConfig {
    /// Width of the output image in pixels.
    pub width: u32,
    /// Height of the output image in pixels.
    pub height: u32,
    /// Nominal dots-per-inch value; kept for callers, unused by this file.
    pub dpi: u32,
}

impl Default for PlotConfig {
    /// Returns a 1600x1200 configuration at 150 dpi.
    fn default() -> Self {
        Self {
            width: 1600,
            height: 1200,
            dpi: 150,
        }
    }
}
27
28/// Training plots generator
29pub struct TrainingPlotter {
30    config: PlotConfig,
31}
32
33impl TrainingPlotter {
34    /// Create new plotter with default config
35    pub fn new() -> Self {
36        Self {
37            config: PlotConfig::default(),
38        }
39    }
40
41    /// Create plotter with custom config
42    pub fn with_config(config: PlotConfig) -> Self {
43        Self { config }
44    }
45
46    /// Generate comprehensive training plots
47    pub fn plot_training_results(&self, metrics: &TrainingMetrics, output_path: &Path) -> Result<()> {
48        info!("Generating training plots to: {}", output_path.display());
49
50        let root = BitMapBackend::new(
51            output_path,
52            (self.config.width, self.config.height)
53        ).into_drawing_area();
54
55        root.fill(&WHITE)
56            .map_err(|e| crate::ExtractionError::ModelError(format!("Plot fill error: {}", e)))?;
57
58        // Split into 2x2 grid
59        let areas = root.split_evenly((2, 2));
60
61        // Plot 1: Episode Rewards
62        self.plot_rewards(&areas[0], &metrics.episode_rewards)?;
63
64        // Plot 2: Episode Quality
65        self.plot_quality(&areas[1], &metrics.episode_qualities)?;
66
67        // Plot 3: Reward Distribution
68        self.plot_reward_distribution(&areas[2], &metrics.episode_rewards)?;
69
70        // Plot 4: Quality Distribution
71        self.plot_quality_distribution(&areas[3], &metrics.episode_qualities)?;
72
73        root.present()
74            .map_err(|e| crate::ExtractionError::ModelError(format!("Plot present error: {}", e)))?;
75
76        info!("Training plots saved successfully");
77        Ok(())
78    }
79
80    /// Plot episode rewards over time with moving average
81    fn plot_rewards<DB: DrawingBackend>(
82        &self,
83        area: &DrawingArea<DB, plotters::coord::Shift>,
84        rewards: &[f32],
85    ) -> Result<()>
86    where
87        DB::ErrorType: 'static,
88    {
89        if rewards.is_empty() {
90            return Ok(());
91        }
92
93        let max_episodes = rewards.len();
94        let max_reward = rewards.iter().copied().fold(f32::NEG_INFINITY, f32::max);
95        let min_reward = rewards.iter().copied().fold(f32::INFINITY, f32::min);
96
97        let mut chart = ChartBuilder::on(area)
98            .caption("Episode Rewards", ("sans-serif", 30).into_font())
99            .margin(10)
100            .x_label_area_size(30)
101            .y_label_area_size(50)
102            .build_cartesian_2d(0..max_episodes, min_reward..max_reward)
103            .map_err(|e| crate::ExtractionError::ModelError(format!("Chart build error: {}", e)))?;
104
105        chart.configure_mesh()
106            .x_desc("Episode")
107            .y_desc("Reward")
108            .draw()
109            .map_err(|e| crate::ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
110
111        // Plot raw rewards
112        chart.draw_series(LineSeries::new(
113            rewards.iter().enumerate().map(|(i, &r)| (i, r)),
114            &BLUE.mix(0.5),
115        ))
116            .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?
117            .label("Raw")
118            .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], BLUE));
119
120        // Plot moving average
121        if rewards.len() > 100 {
122            let window = rewards.len().min(100);
123            let moving_avg = self.calculate_moving_average(rewards, window);
124
125            chart.draw_series(LineSeries::new(
126                moving_avg.into_iter()
127                    .enumerate()
128                    .map(|(i, avg)| (i + window - 1, avg)),
129                &RED,
130            ))
131                .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?
132                .label(format!("MA({})", window))
133                .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], RED));
134        }
135
136        chart.configure_series_labels()
137            .background_style(WHITE.mix(0.8))
138            .border_style(BLACK)
139            .draw()
140            .map_err(|e| crate::ExtractionError::ModelError(format!("Legend error: {}", e)))?;
141
142        Ok(())
143    }
144
145    /// Plot episode quality over time with moving average
146    fn plot_quality<DB: DrawingBackend>(
147        &self,
148        area: &DrawingArea<DB, plotters::coord::Shift>,
149        qualities: &[f32],
150    ) -> Result<()>
151    where
152        DB::ErrorType: 'static,
153    {
154        if qualities.is_empty() {
155            return Ok(());
156        }
157
158        let max_episodes = qualities.len();
159        let max_quality = qualities.iter().copied().fold(f32::NEG_INFINITY, f32::max);
160
161        let mut chart = ChartBuilder::on(area)
162            .caption("Episode Quality", ("sans-serif", 30).into_font())
163            .margin(10)
164            .x_label_area_size(30)
165            .y_label_area_size(50)
166            .build_cartesian_2d(0..max_episodes, 0.0..max_quality.max(1.0))
167            .map_err(|e| crate::ExtractionError::ModelError(format!("Chart build error: {}", e)))?;
168
169        chart.configure_mesh()
170            .x_desc("Episode")
171            .y_desc("Quality Score")
172            .draw()
173            .map_err(|e| crate::ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
174
175        // Plot raw quality
176        chart.draw_series(LineSeries::new(
177            qualities.iter().enumerate().map(|(i, &q)| (i, q)),
178            &GREEN.mix(0.5),
179        ))
180            .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?
181            .label("Raw")
182            .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], GREEN));
183
184        // Plot moving average
185        if qualities.len() > 100 {
186            let window = qualities.len().min(100);
187            let moving_avg = self.calculate_moving_average(qualities, window);
188
189            chart.draw_series(LineSeries::new(
190                moving_avg.into_iter()
191                    .enumerate()
192                    .map(|(i, avg)| (i + window - 1, avg)),
193                &RED,
194            ))
195                .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?
196                .label(format!("MA({})", window))
197                .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], RED));
198        }
199
200        chart.configure_series_labels()
201            .background_style(WHITE.mix(0.8))
202            .border_style(BLACK)
203            .draw()
204            .map_err(|e| crate::ExtractionError::ModelError(format!("Legend error: {}", e)))?;
205
206        Ok(())
207    }
208
209    /// Plot reward distribution histogram
210    fn plot_reward_distribution<DB: DrawingBackend>(
211        &self,
212        area: &DrawingArea<DB, plotters::coord::Shift>,
213        rewards: &[f32],
214    ) -> Result<()>
215    where
216        DB::ErrorType: 'static,
217    {
218        if rewards.is_empty() {
219            return Ok(());
220        }
221
222        let max_reward = rewards.iter().copied().fold(f32::NEG_INFINITY, f32::max);
223        let min_reward = rewards.iter().copied().fold(f32::INFINITY, f32::min);
224
225        // Calculate histogram
226        let n_bins = 50;
227        let bin_width = (max_reward - min_reward) / n_bins as f32;
228        let mut histogram = vec![0usize; n_bins];
229
230        for &reward in rewards {
231            let bin = ((reward - min_reward) / bin_width).floor() as usize;
232            let bin = bin.min(n_bins - 1);
233            histogram[bin] += 1;
234        }
235
236        let max_count = *histogram.iter().max().unwrap_or(&1);
237
238        let mut chart = ChartBuilder::on(area)
239            .caption("Reward Distribution", ("sans-serif", 30).into_font())
240            .margin(10)
241            .x_label_area_size(30)
242            .y_label_area_size(50)
243            .build_cartesian_2d(min_reward..max_reward, 0..max_count)
244            .map_err(|e| crate::ExtractionError::ModelError(format!("Chart build error: {}", e)))?;
245
246        chart.configure_mesh()
247            .x_desc("Reward")
248            .y_desc("Frequency")
249            .draw()
250            .map_err(|e| crate::ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
251
252        // Draw histogram bars
253        chart.draw_series(
254            histogram.iter().enumerate().map(|(i, &count)| {
255                let x0 = min_reward + i as f32 * bin_width;
256                let x1 = x0 + bin_width;
257                Rectangle::new([(x0, 0), (x1, count)], BLUE.mix(0.7).filled())
258            })
259        )
260            .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?;
261
262        // Draw mean line
263        let mean = rewards.iter().sum::<f32>() / rewards.len() as f32;
264        chart.draw_series(LineSeries::new(
265            vec![(mean, 0), (mean, max_count)],
266            RED.stroke_width(2),
267        ))
268            .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?
269            .label(format!("Mean: {:.3}", mean))
270            .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], RED));
271
272        chart.configure_series_labels()
273            .background_style(WHITE.mix(0.8))
274            .border_style(BLACK)
275            .draw()
276            .map_err(|e| crate::ExtractionError::ModelError(format!("Legend error: {}", e)))?;
277
278        Ok(())
279    }
280
281    /// Plot quality distribution histogram
282    fn plot_quality_distribution<DB: DrawingBackend>(
283        &self,
284        area: &DrawingArea<DB, plotters::coord::Shift>,
285        qualities: &[f32],
286    ) -> Result<()>
287    where
288        DB::ErrorType: 'static,
289    {
290        if qualities.is_empty() {
291            return Ok(());
292        }
293
294        let max_quality = qualities.iter().copied().fold(f32::NEG_INFINITY, f32::max);
295        let min_quality = qualities.iter().copied().fold(f32::INFINITY, f32::min);
296
297        // Calculate histogram
298        let n_bins = 50;
299        let bin_width = (max_quality - min_quality).max(0.01) / n_bins as f32;
300        let mut histogram = vec![0usize; n_bins];
301
302        for &quality in qualities {
303            let bin = ((quality - min_quality) / bin_width).floor() as usize;
304            let bin = bin.min(n_bins - 1);
305            histogram[bin] += 1;
306        }
307
308        let max_count = *histogram.iter().max().unwrap_or(&1);
309
310        let mut chart = ChartBuilder::on(area)
311            .caption("Quality Distribution", ("sans-serif", 30).into_font())
312            .margin(10)
313            .x_label_area_size(30)
314            .y_label_area_size(50)
315            .build_cartesian_2d(min_quality..max_quality.max(1.0), 0..max_count)
316            .map_err(|e| crate::ExtractionError::ModelError(format!("Chart build error: {}", e)))?;
317
318        chart.configure_mesh()
319            .x_desc("Quality Score")
320            .y_desc("Frequency")
321            .draw()
322            .map_err(|e| crate::ExtractionError::ModelError(format!("Mesh error: {}", e)))?;
323
324        // Draw histogram bars
325        chart.draw_series(
326            histogram.iter().enumerate().map(|(i, &count)| {
327                let x0 = min_quality + i as f32 * bin_width;
328                let x1 = x0 + bin_width;
329                Rectangle::new([(x0, 0), (x1, count)], GREEN.mix(0.7).filled())
330            })
331        )
332            .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?;
333
334        // Draw mean line
335        let mean = qualities.iter().sum::<f32>() / qualities.len() as f32;
336        chart.draw_series(LineSeries::new(
337            vec![(mean, 0), (mean, max_count)],
338            RED.stroke_width(2),
339        ))
340            .map_err(|e| crate::ExtractionError::ModelError(format!("Series error: {}", e)))?
341            .label(format!("Mean: {:.3}", mean))
342            .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], RED));
343
344        chart.configure_series_labels()
345            .background_style(WHITE.mix(0.8))
346            .border_style(BLACK)
347            .draw()
348            .map_err(|e| crate::ExtractionError::ModelError(format!("Legend error: {}", e)))?;
349
350        Ok(())
351    }
352
353    /// Calculate moving average
354    fn calculate_moving_average(&self, data: &[f32], window: usize) -> Vec<f32> {
355        let mut result = Vec::with_capacity(data.len() - window + 1);
356
357        for i in window - 1..data.len() {
358            let sum: f32 = data[i - window + 1..=i].iter().sum();
359            result.push(sum / window as f32);
360        }
361
362        result
363    }
364
365    /// Generate plot periodically during training
366    pub fn plot_intermediate(&self, metrics: &TrainingMetrics, output_path: &Path, episode: usize) -> Result<()> {
367        let timestamped_path = output_path.parent().unwrap().join(
368            format!("training_plot_ep{}.png", episode)
369        );
370
371        self.plot_training_results(metrics, &timestamped_path)
372    }
373}
374
375impl Default for TrainingPlotter {
376    fn default() -> Self {
377        Self::new()
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds metrics whose rewards and qualities rise linearly over
    /// `episodes` episodes.
    fn ramp_metrics(episodes: usize) -> TrainingMetrics {
        TrainingMetrics {
            episode_rewards: (0..episodes).map(|i| (i as f32 * 0.01) - 0.5).collect(),
            episode_qualities: (0..episodes).map(|i| i as f32 * 0.01).collect(),
            episode_losses: vec![],
            best_avg_quality: 0.9,
        }
    }

    #[test]
    fn test_plot_generation() {
        let temp_dir = TempDir::new().unwrap();
        let plot_path = temp_dir.path().join("test_plot.png");

        let plotter = TrainingPlotter::new();
        plotter.plot_training_results(&ramp_metrics(100), &plot_path).unwrap();

        assert!(plot_path.exists());
    }

    /// More than 100 episodes, so the moving-average series is drawn too.
    #[test]
    fn test_plot_generation_with_moving_average() {
        let temp_dir = TempDir::new().unwrap();
        let plot_path = temp_dir.path().join("test_plot_ma.png");

        let plotter = TrainingPlotter::new();
        plotter.plot_training_results(&ramp_metrics(150), &plot_path).unwrap();

        assert!(plot_path.exists());
    }
}