use crate::utils::*;
use burn::prelude::*;
use hsl::HSL;
use plotters::{
prelude::*,
style::text_anchor::{HPos, Pos, VPos},
};
use std::collections::HashSet;
/// Caption used when a chart configuration does not specify one.
const CAPTION: &str = "fast-umap";
/// Output image path used when a chart configuration does not specify one.
const PATH: &str = "figures/plot.png";
/// Settings for the scatter chart produced by `chart_vector` / `chart_tensor`.
///
/// Build one with [`ChartConfig::builder`] or take the defaults via
/// `ChartConfig::default()`.
#[derive(Debug, Clone)]
pub struct ChartConfig {
/// Caption drawn on the chart.
pub caption: String,
/// Path of the bitmap file the chart is written to.
pub path: String,
/// Width of the output bitmap, in pixels.
pub width: u32,
/// Height of the output bitmap, in pixels.
pub height: u32,
}
impl ChartConfig {
pub fn builder() -> ChartConfigBuilder {
ChartConfigBuilder {
caption: Some(CAPTION.to_string()),
path: Some(PATH.to_string()),
width: Some(1000),
height: Some(1000),
}
}
}
impl Default for ChartConfig {
    /// Default configuration: the module-level caption and output path,
    /// rendered on a 1000x1000 pixel bitmap.
    fn default() -> Self {
        let (width, height) = (1000, 1000);
        Self {
            caption: String::from(CAPTION),
            path: String::from(PATH),
            width,
            height,
        }
    }
}
/// Builder for [`ChartConfig`].
///
/// Each field is optional; `build` substitutes a default for every field
/// that is still `None`.
pub struct ChartConfigBuilder {
/// Caption override, if any.
caption: Option<String>,
/// Output path override, if any.
path: Option<String>,
/// Bitmap width override in pixels, if any.
width: Option<u32>,
/// Bitmap height override in pixels, if any.
height: Option<u32>,
}
impl Default for ChartConfigBuilder {
    /// A builder with the default caption and path already set and the
    /// dimensions left unset (they are resolved when the config is built).
    fn default() -> Self {
        Self {
            caption: Some(String::from(CAPTION)),
            path: Some(String::from(PATH)),
            width: None,
            height: None,
        }
    }
}
impl ChartConfigBuilder {
    /// Sets the caption drawn on the chart.
    pub fn caption(self, caption: &str) -> Self {
        Self {
            caption: Some(caption.to_owned()),
            ..self
        }
    }

    /// Sets the path of the output image file.
    pub fn path(self, path: &str) -> Self {
        Self {
            path: Some(path.to_owned()),
            ..self
        }
    }

    /// Sets the output image width in pixels.
    pub fn width(self, width: u32) -> Self {
        Self {
            width: Some(width),
            ..self
        }
    }

    /// Sets the output image height in pixels.
    pub fn height(self, height: u32) -> Self {
        Self {
            height: Some(height),
            ..self
        }
    }

    /// Consumes the builder and produces the final [`ChartConfig`].
    ///
    /// Any field that was never set falls back to its default: the
    /// module-level caption and path, and 1000 pixels for each dimension.
    pub fn build(self) -> ChartConfig {
        let Self {
            caption,
            path,
            width,
            height,
        } = self;

        ChartConfig {
            caption: caption.unwrap_or_else(|| CAPTION.to_owned()),
            path: path.unwrap_or_else(|| PATH.to_owned()),
            width: width.unwrap_or(1000),
            height: height.unwrap_or(1000),
        }
    }
}
/// Floating-point type used for chart coordinates.
type Float = f64;
/// Charts a rank-2 tensor as a scatter plot.
///
/// The tensor is converted to nested vectors with `convert_tensor_to_vector`
/// and then handed, together with `labels` and `config`, to [`chart_vector`],
/// which does the actual drawing.
pub fn chart_tensor<B: Backend>(
    data: Tensor<B, 2>,
    labels: Option<Vec<String>>,
    config: Option<ChartConfig>,
) {
    let rows: Vec<Vec<Float>> = convert_tensor_to_vector(data);
    chart_vector(rows, labels, config)
}
/// Plots a loss curve (one value per epoch) as a line chart and writes it to
/// `output_path` as an 800x600 bitmap.
///
/// The parent directory of `output_path` is created if it does not exist.
/// The y-axis spans the finite loss values, padded by 10% of the magnitude of
/// the minimum and maximum; non-finite values (NaN / infinity) are left out
/// of both the axis range and the drawn line.
///
/// # Errors
///
/// Returns an error if
/// * `losses` is empty or contains no finite value,
/// * a loss cannot be represented as `f64`,
/// * the number of losses does not fit in a `u32`,
/// * the output directory cannot be created, or
/// * drawing or writing the image fails.
pub fn plot_loss<F: num::Float>(
    losses: Vec<F>,
    output_path: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    if losses.is_empty() {
        return Err("plot_loss: `losses` must contain at least one value".into());
    }

    // Checked conversion: a plain `as u32` would silently truncate.
    let epochs = u32::try_from(losses.len())?;

    let values = losses
        .iter()
        .map(|loss| loss.to_f64())
        .collect::<Option<Vec<f64>>>()
        .ok_or("plot_loss: a loss value is not representable as f64")?;

    // Only finite values may drive the axis range; a single NaN or infinity
    // would otherwise make the range unusable.
    let (min_loss, max_loss) = values
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
            (lo.min(v), hi.max(v))
        });
    if min_loss > max_loss {
        return Err("plot_loss: `losses` contains no finite value".into());
    }

    let padding = 0.1;
    let mut y_min = min_loss - padding * min_loss.abs();
    let mut y_max = max_loss + padding * max_loss.abs();
    if y_min >= y_max {
        // Every loss is exactly zero, so relative padding adds nothing;
        // widen the range so the axis has a non-zero extent.
        y_min -= 0.5;
        y_max += 0.5;
    }

    if let Some(parent) = std::path::Path::new(output_path).parent() {
        if !parent.as_os_str().is_empty() {
            std::fs::create_dir_all(parent)?;
        }
    }

    let root = BitMapBackend::new(output_path, (800, 600)).into_drawing_area();
    root.fill(&WHITE)?;

    let mut chart = ChartBuilder::on(&root)
        .caption("Loss Over Epochs", ("sans-serif", 30))
        .set_label_area_size(LabelAreaPosition::Left, 80)
        .set_label_area_size(LabelAreaPosition::Bottom, 50)
        .build_cartesian_2d(0..epochs, y_min..y_max)?;

    // Draw the mesh once, before the series, so grid lines stay underneath
    // the loss curve.
    chart
        .configure_mesh()
        .y_desc("Loss")
        .x_desc("Epochs")
        .y_labels(10)
        .draw()?;

    chart
        .draw_series(LineSeries::new(
            values
                .iter()
                .enumerate()
                .filter(|(_, v)| v.is_finite())
                // `epoch < losses.len()`, which was checked to fit in a u32.
                .map(|(epoch, &v)| (epoch as u32, v)),
            &BLUE,
        ))?
        .label("Loss")
        // The legend swatch is a short line in the same colour as the series.
        .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], &BLUE));

    chart.configure_series_labels().draw()?;

    // Flush explicitly so that a failure to write the image is reported
    // instead of being swallowed when the backend is dropped.
    root.present()?;
    Ok(())
}
pub fn chart_vector(
data: Vec<Vec<Float>>,
labels: Option<Vec<String>>,
config: Option<ChartConfig>,
) {
let config = config.unwrap_or(ChartConfig::default());
if let Some(parent) = std::path::Path::new(&config.path).parent() {
if !parent.as_os_str().is_empty() {
std::fs::create_dir_all(parent).expect("Could not create chart output directory");
}
}
let root = BitMapBackend::new(&config.path, (config.width, config.height)).into_drawing_area();
root.fill(&WHITE).unwrap();
let min_x = data
.iter()
.flat_map(|v| v.iter().step_by(2))
.cloned()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap() as Float;
let max_x = data
.iter()
.flat_map(|v| v.iter().step_by(2))
.cloned()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap() as Float;
let min_y = data
.iter()
.flat_map(|v| v.iter().skip(1).step_by(2))
.cloned()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap() as Float;
let max_y = data
.iter()
.flat_map(|v| v.iter().skip(1).step_by(2))
.cloned()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap() as Float;
let mut label_colors: Vec<(String, RGBColor)> = Vec::new();
if let Some(labels) = labels.clone() {
let mut unique_labels: Vec<String> = labels
.into_iter()
.collect::<HashSet<String>>()
.into_iter()
.collect();
unique_labels.sort_by(|a, b| match (a.parse::<f64>(), b.parse::<f64>()) {
(Ok(af), Ok(bf)) => af.partial_cmp(&bf).unwrap_or_else(|| a.cmp(b)),
_ => a.cmp(b),
});
for (i, label) in unique_labels.iter().enumerate() {
let hue = i as f64 * 360.0 / unique_labels.len() as f64;
let color = HSL { h: hue, s: 0.7, l: 0.5 }.to_rgb();
label_colors.push((label.clone(), RGBColor(color.0, color.1, color.2)));
}
}
let mut chart = ChartBuilder::on(&root)
.caption(config.caption, ("sans-serif", 30))
.margin(40)
.x_label_area_size(30)
.y_label_area_size(30)
.build_cartesian_2d(min_x..max_x, min_y..max_y)
.unwrap();
chart
.configure_mesh()
.x_desc("X Axis")
.y_desc("Y Axis")
.x_labels(10)
.y_labels(10)
.draw()
.unwrap();
let mut series_list: Vec<(String, Vec<(f64, f64)>, RGBColor)> = Vec::new();
chart
.draw_series(data.iter().enumerate().map(|(i, values)| {
let label = labels
.clone()
.map(|l| l.get(i).cloned())
.flatten()
.unwrap_or_else(|| "".into());
let color = label_colors
.iter()
.find(|(l, _)| *l == label)
.map(|(_, c)| *c)
.unwrap_or(RED);
if !label.is_empty() {
let series_data = series_list.iter_mut().find(|(l, _, _)| *l == label);
match series_data {
Some((_, series_points, _)) => series_points.push((values[0], values[1])),
None => series_list.push((label.clone(), vec![(values[0], values[1])], color)),
}
}
Circle::new(
(values[0], values[1]),
3,
ShapeStyle {
color: color.into(),
filled: false,
stroke_width: 1,
},
)
}))
.unwrap();
if labels.is_some() {
series_list.sort_by(|a, b| {
let a = a.0.parse::<usize>().unwrap();
let b = b.0.parse::<usize>().unwrap();
a.cmp(&b)
});
let spacing_y = (max_y - min_y) / (series_list.len() * 2) as f64;
let mut legend_position = (min_x + (max_x - min_x) * 0.8, max_y - (max_y - min_y) * 0.1);
let size = 5.0; let font_size = 15.0;
for (label, _, color) in series_list {
chart
.draw_series(std::iter::once(Circle::new(
legend_position,
size,
ShapeStyle {
color: color.into(),
filled: true,
stroke_width: 1,
},
)))
.unwrap();
let style = TextStyle {
font: ("sans-serif", font_size).into_font(),
color: BLACK.to_backend_color(),
pos: Pos::new(HPos::Left, VPos::Center),
};
chart
.draw_series(std::iter::once(Text::new(
label,
(legend_position.0 + spacing_y / 4.0, legend_position.1),
style,
)))
.unwrap();
legend_position.1 -= spacing_y;
}
}
root.present().unwrap();
}