use crate::utils::*;
use burn::prelude::*;
use plotters::prelude::*;
/// Default chart title used when no caption is configured.
const CAPTION: &str = "fast-umap";
/// Default output file path used when no path is configured.
const PATH: &str = "plot.png";
/// Default image width used when no width is configured.
const DEFAULT_WIDTH: u32 = 1000;
/// Default image height used when no height is configured.
const DEFAULT_HEIGHT: u32 = 1000;

/// Output settings for a rendered chart: its title, the file it is written
/// to, and the image dimensions passed to the bitmap backend.
///
/// Build one with [`ChartConfig::builder`] or start from
/// [`ChartConfig::default`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChartConfig {
    /// Title drawn above the chart.
    pub caption: String,
    /// File path the rendered image is written to.
    pub path: String,
    /// Image width (bitmap backend units, i.e. pixels).
    pub width: u32,
    /// Image height (bitmap backend units, i.e. pixels).
    pub height: u32,
}

impl ChartConfig {
    /// Returns a [`ChartConfigBuilder`] pre-filled with the default settings,
    /// so only the fields that differ from the defaults need to be set.
    pub fn builder() -> ChartConfigBuilder {
        ChartConfigBuilder {
            caption: Some(CAPTION.to_string()),
            path: Some(PATH.to_string()),
            width: Some(DEFAULT_WIDTH),
            height: Some(DEFAULT_HEIGHT),
        }
    }
}

impl Default for ChartConfig {
    /// Caption `"fast-umap"`, path `"plot.png"`, 1000 x 1000 image.
    fn default() -> Self {
        ChartConfig {
            caption: CAPTION.to_string(),
            path: PATH.to_string(),
            width: DEFAULT_WIDTH,
            height: DEFAULT_HEIGHT,
        }
    }
}

/// Fluent builder for [`ChartConfig`]; obtain one via [`ChartConfig::builder`].
///
/// Any field left unset falls back to the same default as
/// [`ChartConfig::default`].
#[derive(Debug, Clone)]
#[must_use = "a ChartConfigBuilder does nothing until `build` is called"]
pub struct ChartConfigBuilder {
    caption: Option<String>,
    path: Option<String>,
    width: Option<u32>,
    height: Option<u32>,
}

impl ChartConfigBuilder {
    /// Sets the chart title.
    pub fn caption(mut self, caption: &str) -> Self {
        self.caption = Some(caption.to_string());
        self
    }

    /// Sets the output file path.
    pub fn path(mut self, path: &str) -> Self {
        self.path = Some(path.to_string());
        self
    }

    /// Sets the image width.
    pub fn width(mut self, width: u32) -> Self {
        self.width = Some(width);
        self
    }

    /// Sets the image height.
    pub fn height(mut self, height: u32) -> Self {
        self.height = Some(height);
        self
    }

    /// Produces the final [`ChartConfig`], substituting the defaults for any
    /// field that was never set.
    pub fn build(self) -> ChartConfig {
        ChartConfig {
            caption: self.caption.unwrap_or_else(|| CAPTION.to_string()),
            path: self.path.unwrap_or_else(|| PATH.to_string()),
            width: self.width.unwrap_or(DEFAULT_WIDTH),
            height: self.height.unwrap_or(DEFAULT_HEIGHT),
        }
    }
}
/// Scalar type of the point coordinates handled by the charting functions in this file.
type Float = f64;
/// Plots a rank-2 tensor as a scatter chart.
///
/// The tensor is converted into nested `Vec<Vec<Float>>` rows by
/// `convert_tensor_to_vector`, and those rows are passed unchanged, together
/// with `config`, to [`chart_vector`].
pub fn chart_tensor<B: Backend>(data: Tensor<B, 2>, config: Option<ChartConfig>) {
    chart_vector(convert_tensor_to_vector(data), config)
}
pub fn chart_vector(data: Vec<Vec<Float>>, config: Option<ChartConfig>) {
let config = config.unwrap_or(ChartConfig::default());
let root = BitMapBackend::new(&config.path, (config.width, config.height)).into_drawing_area();
root.fill(&WHITE).unwrap();
let min_x = data
.iter()
.flat_map(|v| v.iter().step_by(2)) .cloned()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap() as Float;
let max_x = data
.iter()
.flat_map(|v| v.iter().step_by(2)) .cloned()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap() as Float;
let min_y = data
.iter()
.flat_map(|v| v.iter().skip(1).step_by(2)) .cloned()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap() as Float;
let max_y = data
.iter()
.flat_map(|v| v.iter().skip(1).step_by(2)) .cloned()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap() as Float;
let mut chart = ChartBuilder::on(&root)
.caption(config.caption, ("sans-serif", 30))
.margin(40)
.x_label_area_size(30)
.y_label_area_size(30)
.build_cartesian_2d(min_x..max_x, min_y..max_y)
.unwrap();
chart
.configure_mesh()
.x_desc("X Axis")
.y_desc("Y Axis")
.x_labels(10)
.y_labels(10)
.draw()
.unwrap();
chart
.draw_series(data.iter().map(|values| {
Circle::new(
(values[0], values[1]),
5,
ShapeStyle {
color: RED.to_rgba(),
filled: true,
stroke_width: 1,
},
)
}))
.unwrap()
.label("UMAP")
.legend(move |(x, y)| {
Circle::new(
(x, y),
5,
ShapeStyle {
color: RED.to_rgba(),
filled: true,
stroke_width: 1,
},
)
});
chart.configure_mesh().draw().unwrap();
root.present().unwrap();
}
pub fn plot_loss(losses: Vec<f64>, output_path: &str) -> Result<(), Box<dyn std::error::Error>> {
let min_loss = losses.iter().cloned().fold(f64::INFINITY, f64::min);
let max_loss = losses.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let padding = 0.1; let min_loss_with_padding = min_loss - padding * min_loss.abs();
let max_loss_with_padding = max_loss + padding * max_loss.abs();
let root = BitMapBackend::new(output_path, (800, 600)).into_drawing_area();
root.fill(&WHITE)?;
let mut chart = ChartBuilder::on(&root)
.caption("Loss Over Epochs", ("sans-serif", 30))
.set_label_area_size(LabelAreaPosition::Left, 80)
.set_label_area_size(LabelAreaPosition::Bottom, 50)
.build_cartesian_2d(
0..losses.len() as u32,
min_loss_with_padding..max_loss_with_padding,
)?;
chart
.configure_mesh()
.y_desc("Loss")
.x_desc("Epochs")
.draw()?;
chart
.draw_series(LineSeries::new(
(0..losses.len()).map(|x| (x as u32, losses[x])),
&BLUE,
))?
.label("Loss")
.legend(move |(x, y)| PathElement::new(vec![(x, y)], &RED));
chart.configure_series_labels().draw()?;
chart.configure_mesh().y_labels(10).draw()?;
Ok(())
}