use std::collections::HashMap;
use ndarray::Array2;
use num_traits::ToPrimitive;
use plotters::prelude::*;
use crate::symbolic::core::DagOp;
use crate::symbolic::core::Expr;
/// Rendering options shared by the plotting functions in this module.
///
/// Every plotting function takes an `Option<PlotConfig>`; `None` falls back to
/// [`PlotConfig::default`].
#[derive(Clone, Debug, PartialEq)]
pub struct PlotConfig {
    /// Output image width in pixels (passed to `BitMapBackend::new`).
    pub width: u32,
    /// Output image height in pixels (passed to `BitMapBackend::new`).
    pub height: u32,
    /// Chart title drawn above the plotting area.
    pub caption: String,
    /// Colour used for curves, vector-field arrows and surfaces.
    pub line_color: RGBAColor,
    /// Colour of the light grid lines of 2D meshes.
    pub mesh_color: RGBAColor,
    /// Sampling density. Curves are drawn with `samples` segments; vector fields
    /// and surfaces use a grid of roughly `sqrt(samples)` (2D) or
    /// `cbrt(samples)` (3D) points per axis.
    pub samples: usize,
    /// Camera pitch for 3D charts.
    /// NOTE(review): intended for plotters' projection builder (`pitch`) —
    /// confirm each 3D plotting function applies it.
    pub pitch: f64,
    /// Camera yaw for 3D charts.
    /// NOTE(review): intended for plotters' projection builder (`yaw`) —
    /// confirm each 3D plotting function applies it.
    pub yaw: f64,
    /// Zoom factor for 3D charts.
    /// NOTE(review): intended for plotters' projection builder (`scale`) —
    /// confirm each 3D plotting function applies it.
    pub scale: f64,
    /// Font size for axis tick labels and legend entries, where a plotting
    /// function honours it.
    pub label_font_size: u32,
    /// Font size for the chart caption, where a plotting function honours it.
    pub caption_font_size: u32,
}
impl Default for PlotConfig {
    /// An 800×600 chart captioned "Plot": red curves, a black mesh at 10%
    /// opacity, 500 samples, a 0.5 / 0.5 / 0.7 pitch / yaw / scale view, and
    /// 20-point labels under a 40-point caption.
    fn default() -> Self {
        let line_color = RED.to_rgba();
        let mesh_color = BLACK.mix(0.1);
        Self {
            caption: String::from("Plot"),
            width: 800,
            height: 600,
            line_color,
            mesh_color,
            samples: 500,
            pitch: 0.5,
            yaw: 0.5,
            scale: 0.7,
            caption_font_size: 40,
            label_font_size: 20,
        }
    }
}
/// Numerically evaluates `root_expr`, substituting variable values from `vars`.
///
/// The expression graph is walked iteratively with an explicit stack in
/// post-order, so deeply nested expressions cannot overflow the call stack. Each
/// distinct sub-expression is evaluated only once, because results are memoised
/// in a map keyed by the sub-expression.
///
/// Numerical domain problems are not errors. Operations follow IEEE-754 `f64`
/// semantics, so e.g. `ln(-1)` yields `NaN` and `1 / 0` yields `inf`.
///
/// # Errors
///
/// Returns a human-readable message when:
/// - a variable is missing from `vars`,
/// - a big-integer or rational constant cannot be converted to `f64`,
/// - a node has fewer operands than its operation needs, or
/// - the operation has no numerical implementation here.
pub(crate) fn eval_expr(
    root_expr: &Expr,
    vars: &HashMap<String, f64>,
) -> Result<f64, String> {
    let mut results: HashMap<Expr, f64> = HashMap::new();
    let mut stack: Vec<Expr> = vec![root_expr.clone()];
    // Nodes whose operands have already been pushed. Meeting one of them again on
    // top of the stack means every operand above it has been evaluated.
    let mut visited = std::collections::HashSet::new();
    while let Some(expr) = stack.last() {
        if results.contains_key(expr) {
            stack.pop();
            continue;
        }
        let children = expr.children();
        if children.is_empty() || visited.contains(expr) {
            let current_expr = stack.pop().expect("Expr present");
            // Looks up an already-evaluated operand. A missing entry means the node
            // has fewer children than its operation requires (malformed input).
            // That case is reported as an error instead of panicking on an
            // out-of-bounds index.
            let get_child_val = |i: usize| -> Result<f64, String> {
                children
                    .get(i)
                    .and_then(|child| results.get(child))
                    .copied()
                    .ok_or_else(|| {
                        format!("Operation {:?} is missing operand {i}", current_expr.op())
                    })
            };
            let val = match current_expr.op() {
                | DagOp::Constant(c) => c.into_inner(),
                | DagOp::BigInt(i) => {
                    i.to_f64()
                        .ok_or_else(|| "BigInt conversion to f64 failed".to_string())?
                },
                | DagOp::Rational(r) => {
                    // Numerator or denominator may not be representable as f64;
                    // report that rather than panicking.
                    match (r.numer().to_f64(), r.denom().to_f64()) {
                        | (Some(numer), Some(denom)) => numer / denom,
                        | _ => return Err("Rational conversion to f64 failed".to_string()),
                    }
                },
                | DagOp::Variable(v) => {
                    vars.get(&v)
                        .copied()
                        .ok_or_else(|| format!("Variable '{v}' not found"))?
                },
                | DagOp::Add => get_child_val(0)? + get_child_val(1)?,
                | DagOp::Sub => get_child_val(0)? - get_child_val(1)?,
                | DagOp::Mul => get_child_val(0)? * get_child_val(1)?,
                | DagOp::Div => get_child_val(0)? / get_child_val(1)?,
                | DagOp::Power => get_child_val(0)?.powf(get_child_val(1)?),
                | DagOp::Neg => -get_child_val(0)?,
                | DagOp::Sqrt => get_child_val(0)?.sqrt(),
                | DagOp::Abs => get_child_val(0)?.abs(),
                | DagOp::Sin => get_child_val(0)?.sin(),
                | DagOp::Cos => get_child_val(0)?.cos(),
                | DagOp::Tan => get_child_val(0)?.tan(),
                | DagOp::Csc => 1.0 / get_child_val(0)?.sin(),
                | DagOp::Sec => 1.0 / get_child_val(0)?.cos(),
                | DagOp::Cot => 1.0 / get_child_val(0)?.tan(),
                | DagOp::ArcSin => get_child_val(0)?.asin(),
                | DagOp::ArcCos => get_child_val(0)?.acos(),
                | DagOp::ArcTan => get_child_val(0)?.atan(),
                | DagOp::Sinh => get_child_val(0)?.sinh(),
                | DagOp::Cosh => get_child_val(0)?.cosh(),
                | DagOp::Tanh => get_child_val(0)?.tanh(),
                | DagOp::Log => get_child_val(0)?.ln(),
                // Operand 0 is the base and operand 1 the argument: `f64::log(self, base)`.
                | DagOp::LogBase => get_child_val(1)?.log(get_child_val(0)?),
                | DagOp::Exp => get_child_val(0)?.exp(),
                | DagOp::Floor => get_child_val(0)?.floor(),
                | DagOp::Pi => std::f64::consts::PI,
                | DagOp::E => std::f64::consts::E,
                | _ => {
                    return Err(format!(
                        "Numerical evaluation for operation {:?} is not implemented",
                        current_expr.op()
                    ));
                },
            };
            results.insert(current_expr, val);
        } else {
            visited.insert(expr.clone());
            // Push operands in reverse so they are evaluated left to right.
            stack.extend(children.iter().rev().cloned());
        }
    }
    // The loop only empties the stack after the root itself has been evaluated.
    results
        .get(root_expr)
        .copied()
        .ok_or_else(|| "Root expression was not evaluated".to_string())
}
/// Plots `expr` as a function of the single variable `var` over `range` and
/// writes the chart as a bitmap to `path`.
///
/// The function is sampled once at `samples + 1` evenly spaced points (at least
/// two). Samples that fail to evaluate, or that evaluate to NaN/±inf (poles,
/// domain errors), are left out. They split the curve into separate pieces
/// instead of being bridged or drawn at a fake value. The y axis is fitted to
/// the remaining points, and a flat function gets a ±1 band so the axis is
/// never empty.
///
/// # Errors
///
/// Returns `Err` if no sample evaluates to a finite value, or if plotters fails
/// to build or write the chart.
pub fn plot_function_2d(
    expr: &Expr,
    var: &str,
    range: (f64, f64),
    path: &str,
    config: Option<PlotConfig>,
) -> Result<(), String> {
    let conf = config.unwrap_or_default();
    // `samples == 0` would divide by zero below; always draw at least one segment.
    let samples = conf.samples.max(1);
    let sampled: Vec<Option<(f64, f64)>> = (0..=samples)
        .map(|i| {
            let x = (range.1 - range.0).mul_add(i as f64 / samples as f64, range.0);
            eval_expr(expr, &HashMap::from([(var.to_string(), x)]))
                .ok()
                .filter(|y| y.is_finite())
                .map(|y| (x, y))
        })
        .collect();
    let (mut y_min, mut y_max) = sampled
        .iter()
        .flatten()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &(_, y)| {
            (lo.min(y), hi.max(y))
        });
    // Evaluation happens before the backend is created, so this fails fast
    // without rendering anything.
    if y_min > y_max {
        return Err(format!(
            "Function could not be evaluated to a finite value on [{}, {}]",
            range.0, range.1
        ));
    }
    if y_max - y_min < 1e-9 {
        y_min -= 1.0;
        y_max += 1.0;
    }
    let root = BitMapBackend::new(path, (conf.width, conf.height)).into_drawing_area();
    root.fill(&WHITE).map_err(|e| e.to_string())?;
    let mut chart = ChartBuilder::on(&root)
        .caption(&conf.caption, ("sans-serif", conf.caption_font_size).into_font())
        .margin(5)
        .x_label_area_size(30)
        .y_label_area_size(30)
        .build_cartesian_2d(range.0..range.1, y_min..y_max)
        .map_err(|e| e.to_string())?;
    chart
        .configure_mesh()
        .light_line_style(conf.mesh_color)
        .draw()
        .map_err(|e| e.to_string())?;
    // Each maximal run of valid samples is drawn as its own line.
    for run in sampled.split(Option::is_none).filter(|run| !run.is_empty()) {
        chart
            .draw_series(LineSeries::new(
                run.iter().flatten().copied(),
                &conf.line_color,
            ))
            .map_err(|e| e.to_string())?;
    }
    root.present().map_err(|e| e.to_string())?;
    Ok(())
}
/// Plots labelled `(x, y)` data series as lines on a shared 2D chart with a
/// legend, and writes the bitmap to `path`.
///
/// The axes cover every finite point of every series. The x axis is padded by
/// 5% and the y axis by 10%; an axis with no spread gets a ±1 band instead.
/// Series colours cycle through red, blue, green, cyan, magenta, yellow and
/// black. Points with a NaN or infinite coordinate are not drawn.
///
/// # Errors
///
/// Returns `Err` if:
/// - `series` is empty,
/// - no series contains a finite point, or
/// - plotters fails to build or write the chart.
pub fn plot_series_2d(
    series: &[(String, Vec<(f64, f64)>)],
    path: &str,
    config: Option<PlotConfig>,
) -> Result<(), String> {
    // Validate before creating the backend so bad input renders nothing.
    if series.is_empty() {
        return Err("No data series provided".to_string());
    }
    let conf = config.unwrap_or_default();
    let (mut x_min, mut x_max) = (f64::INFINITY, f64::NEG_INFINITY);
    let (mut y_min, mut y_max) = (f64::INFINITY, f64::NEG_INFINITY);
    for &(x, y) in series.iter().flat_map(|(_, data)| data) {
        if !(x.is_finite() && y.is_finite()) {
            continue;
        }
        x_min = x_min.min(x);
        x_max = x_max.max(x);
        y_min = y_min.min(y);
        y_max = y_max.max(y);
    }
    // The bounds are still inverted (INF > -INF) only when no finite point exists.
    if x_min > x_max {
        return Err("Data series contain no finite points".to_string());
    }
    let x_pad = if x_max > x_min { (x_max - x_min) * 0.05 } else { 1.0 };
    let y_pad = if y_max > y_min { (y_max - y_min) * 0.1 } else { 1.0 };
    let x_range = (x_min - x_pad)..(x_max + x_pad);
    let y_range = (y_min - y_pad)..(y_max + y_pad);
    let root = BitMapBackend::new(path, (conf.width, conf.height)).into_drawing_area();
    root.fill(&WHITE).map_err(|e| e.to_string())?;
    let mut chart = ChartBuilder::on(&root)
        .caption(
            &conf.caption,
            ("sans-serif", conf.caption_font_size).into_font(),
        )
        .margin(conf.label_font_size / 2)
        .x_label_area_size(conf.label_font_size * 2)
        .y_label_area_size(conf.label_font_size * 2)
        .build_cartesian_2d(x_range, y_range)
        .map_err(|e| e.to_string())?;
    chart
        .configure_mesh()
        .light_line_style(conf.mesh_color)
        .label_style(("sans-serif", conf.label_font_size).into_font())
        .draw()
        .map_err(|e| e.to_string())?;
    let colors = [&RED, &BLUE, &GREEN, &CYAN, &MAGENTA, &YELLOW, &BLACK];
    for ((label, data), &color) in series.iter().zip(colors.iter().cycle()) {
        let finite_points = data
            .iter()
            .copied()
            .filter(|(x, y)| x.is_finite() && y.is_finite());
        chart
            .draw_series(LineSeries::new(finite_points, color))
            .map_err(|e| e.to_string())?
            .label(label)
            .legend(move |(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], color));
    }
    chart
        .configure_series_labels()
        .background_style(WHITE.mix(0.8))
        .border_style(BLACK)
        .label_font(("sans-serif", conf.label_font_size).into_font())
        .draw()
        .map_err(|e| e.to_string())?;
    root.present().map_err(|e| e.to_string())?;
    Ok(())
}
/// Draws the 2D vector field `(comps.0, comps.1)` over the rectangle
/// `x_range × y_range` and writes the bitmap to `path`.
///
/// The field is sampled on a square grid of about `sqrt(samples)` points per
/// axis; a 1×1 grid samples the centre of the rectangle. At each grid point
/// where both components evaluate, a segment of length 5% of the x-range is
/// drawn in the field's direction. Only direction is shown, not magnitude.
/// Points that fail to evaluate are skipped, and so are near-zero vectors
/// (magnitude ≤ 1e-9).
///
/// # Errors
///
/// Returns `Err` if plotters fails to build or write the chart.
pub fn plot_vector_field_2d(
    comps: (&Expr, &Expr),
    vars: (&str, &str),
    x_range: (f64, f64),
    y_range: (f64, f64),
    path: &str,
    config: Option<PlotConfig>,
) -> Result<(), String> {
    let conf = config.unwrap_or_default();
    let root = BitMapBackend::new(path, (conf.width, conf.height)).into_drawing_area();
    root.fill(&WHITE).map_err(|e| e.to_string())?;
    let mut chart = ChartBuilder::on(&root)
        .caption(&conf.caption, ("sans-serif", conf.caption_font_size).into_font())
        .build_cartesian_2d(x_range.0..x_range.1, y_range.0..y_range.1)
        .map_err(|e| e.to_string())?;
    chart
        .configure_mesh()
        .light_line_style(conf.mesh_color)
        .draw()
        .map_err(|e| e.to_string())?;
    let (vx_expr, vy_expr) = comps;
    let (x_var, y_var) = vars;
    let steps = (conf.samples as f64).sqrt().round() as usize;
    // Grid coordinate `k` along `(lo, hi)`. With a single step, `steps - 1` is 0
    // and the division would produce NaN, so that case uses the midpoint.
    let grid = |(lo, hi): (f64, f64), k: usize| -> f64 {
        if steps > 1 {
            (hi - lo).mul_add(k as f64 / (steps - 1) as f64, lo)
        } else {
            0.5 * (lo + hi)
        }
    };
    let arrow_len = (x_range.1 - x_range.0) * 0.05;
    let mut arrows = Vec::new();
    for i in 0..steps {
        for j in 0..steps {
            let (x, y) = (grid(x_range, i), grid(y_range, j));
            let vars_map = HashMap::from([(x_var.to_string(), x), (y_var.to_string(), y)]);
            if let (Ok(vx), Ok(vy)) =
                (eval_expr(vx_expr, &vars_map), eval_expr(vy_expr, &vars_map))
            {
                let magnitude = vx.hypot(vy);
                if magnitude > 1e-9 {
                    let end_x = (vx / magnitude).mul_add(arrow_len, x);
                    let end_y = (vy / magnitude).mul_add(arrow_len, y);
                    arrows.push(PathElement::new(
                        vec![(x, y), (end_x, end_y)],
                        conf.line_color,
                    ));
                }
            }
        }
    }
    chart.draw_series(arrows).map_err(|e| e.to_string())?;
    root.present().map_err(|e| e.to_string())?;
    Ok(())
}
/// Plots the surface `value = expr(vars.0, vars.1)` over `x_range × y_range` as
/// a filled 3D mesh and writes the bitmap to `path`.
///
/// `vars.0` runs along the chart's horizontal x axis and `vars.1` along its depth
/// axis; the function value is drawn on the vertical axis. The grid has about
/// `sqrt(samples)` points per axis, with at least 2. Each grid point is evaluated
/// once. Points that fail to evaluate or are non-finite are drawn at height 0.0.
/// The vertical axis is fitted to the values with 10% padding, or ±1 for a flat
/// surface. The 3D view uses `pitch`, `yaw` and `scale` from the config.
///
/// # Errors
///
/// Returns `Err` if plotters fails to build, draw or write the chart.
pub fn plot_surface_3d(
    expr: &Expr,
    vars: (&str, &str),
    x_range: (f64, f64),
    y_range: (f64, f64),
    path: &str,
    config: Option<PlotConfig>,
) -> Result<(), String> {
    let conf = config.unwrap_or_default();
    let (x_var, y_var) = vars;
    // A surface needs at least a 2×2 grid; this also keeps `steps - 1` non-zero.
    let steps = ((conf.samples as f64).sqrt().round() as usize).max(2);
    let axis = |(lo, hi): (f64, f64)| -> Vec<f64> {
        (0..steps)
            .map(|i| lo + (hi - lo) * (i as f64) / (steps - 1) as f64)
            .collect()
    };
    let xs = axis(x_range);
    let zs = axis(y_range);
    // Heights keyed by the exact bit patterns of the grid coordinates. SurfaceSeries
    // hands the same f64 values back to the lookup closure below.
    let mut heights: HashMap<(u64, u64), f64> = HashMap::with_capacity(steps * steps);
    for &x in &xs {
        for &z in &zs {
            let vars_map = HashMap::from([(x_var.to_string(), x), (y_var.to_string(), z)]);
            let value = eval_expr(expr, &vars_map)
                .ok()
                .filter(|v| v.is_finite())
                .unwrap_or(0.0);
            heights.insert((x.to_bits(), z.to_bits()), value);
        }
    }
    let (lo, hi) = heights
        .values()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| (lo.min(v), hi.max(v)));
    let (v_min, v_max) = if hi - lo < 1e-9 {
        (lo - 1.0, hi + 1.0)
    } else {
        let pad = (hi - lo) * 0.1;
        (lo - pad, hi + pad)
    };
    let root = BitMapBackend::new(path, (conf.width, conf.height)).into_drawing_area();
    root.fill(&WHITE).map_err(|e| e.to_string())?;
    let mut chart = ChartBuilder::on(&root)
        .caption(&conf.caption, ("sans-serif", conf.caption_font_size).into_font())
        .build_cartesian_3d(x_range.0..x_range.1, v_min..v_max, y_range.0..y_range.1)
        .map_err(|e| e.to_string())?;
    chart.with_projection(|mut pb| {
        pb.pitch = conf.pitch;
        pb.yaw = conf.yaw;
        pb.scale = conf.scale;
        pb.into_matrix()
    });
    chart.configure_axes().draw().map_err(|e| e.to_string())?;
    chart
        .draw_series(
            SurfaceSeries::xoz(xs.iter().copied(), zs.iter().copied(), |x: f64, z: f64| {
                heights.get(&(x.to_bits(), z.to_bits())).copied().unwrap_or(0.0)
            })
            .style(conf.line_color.mix(0.5).filled()),
        )
        .map_err(|e| e.to_string())?;
    root.present().map_err(|e| e.to_string())?;
    Ok(())
}
/// Plots a 2D array as a filled 3D surface and writes the bitmap to `path`.
///
/// Element `data[[row, col]]` is drawn at horizontal position `col`, at depth
/// `row`, with its value on the vertical axis. The vertical axis spans the
/// finite values padded by 10%, or ±1 when all finite values are equal. The 3D
/// view uses `pitch`, `yaw` and `scale` from the config.
///
/// # Errors
///
/// Returns `Err` if:
/// - `data` has no rows or no columns,
/// - `data` contains no finite value, or
/// - plotters fails to build, draw or write the chart.
pub fn plot_surface_2d(
    data: &Array2<f64>,
    path: &str,
    config: Option<PlotConfig>,
) -> Result<(), String> {
    let conf = config.unwrap_or_default();
    let (height, width) = data.dim();
    if height == 0 || width == 0 {
        return Err("Cannot plot a surface from an empty array".to_string());
    }
    let (min_val, max_val) = data
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| (lo.min(v), hi.max(v)));
    if min_val > max_val {
        return Err("Array contains no finite values".to_string());
    }
    let spread = max_val - min_val;
    let (z_min, z_max) = if spread < 1e-9 {
        (min_val - 1.0, max_val + 1.0)
    } else {
        (spread.mul_add(-0.1, min_val), spread.mul_add(0.1, max_val))
    };
    let root = BitMapBackend::new(path, (conf.width, conf.height)).into_drawing_area();
    root.fill(&WHITE).map_err(|e| e.to_string())?;
    let mut chart = ChartBuilder::on(&root)
        .caption(&conf.caption, ("sans-serif", conf.caption_font_size).into_font())
        .build_cartesian_3d(0.0..width as f64, z_min..z_max, 0.0..height as f64)
        .map_err(|e| e.to_string())?;
    chart.with_projection(|mut pb| {
        pb.pitch = conf.pitch;
        pb.yaw = conf.yaw;
        pb.scale = conf.scale;
        pb.into_matrix()
    });
    chart.configure_axes().draw().map_err(|e| e.to_string())?;
    chart
        .draw_series(
            SurfaceSeries::xoz(
                (0..width).map(|col| col as f64),
                (0..height).map(|row| row as f64),
                // Grid coordinates are exact non-negative integers, so rounding
                // recovers the index; `min` keeps it in bounds regardless.
                |x: f64, z: f64| {
                    let col = (x.round() as usize).min(width - 1);
                    let row = (z.round() as usize).min(height - 1);
                    data[[row, col]]
                },
            )
            .style(conf.line_color.mix(0.5).filled()),
        )
        .map_err(|e| e.to_string())?;
    root.present().map_err(|e| e.to_string())?;
    Ok(())
}
/// Plots the 3D parametric curve `t ↦ (comps.0, comps.1, comps.2)` for the
/// parameter `var` over `range`, and writes the bitmap to `path`.
///
/// The curve is sampled at `samples + 1` evenly spaced parameter values (at
/// least two). A sample is skipped when any component fails to evaluate or is
/// non-finite, and the neighbouring valid points are joined directly. Each axis
/// is fitted to the sampled points with 5% padding, or ±1 when flat. The 3D
/// view uses `pitch`, `yaw` and `scale` from the config.
///
/// # Errors
///
/// Returns `Err` if no sample evaluates to a finite point, or if plotters fails
/// to build, draw or write the chart.
pub fn plot_parametric_curve_3d(
    comps: (&Expr, &Expr, &Expr),
    var: &str,
    range: (f64, f64),
    path: &str,
    config: Option<PlotConfig>,
) -> Result<(), String> {
    /// Range spanning `values` with 5% padding on each side, or ±1 when flat.
    fn axis_range(values: impl Iterator<Item = f64>) -> std::ops::Range<f64> {
        let (lo, hi) = values
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| (lo.min(v), hi.max(v)));
        let pad = if hi - lo < 1e-9 { 1.0 } else { (hi - lo) * 0.05 };
        (lo - pad)..(hi + pad)
    }

    let conf = config.unwrap_or_default();
    let (x_expr, y_expr, z_expr) = comps;
    // `samples == 0` would divide by zero below.
    let samples = conf.samples.max(1);
    let points: Vec<(f64, f64, f64)> = (0..=samples)
        .filter_map(|i| {
            let t = (range.1 - range.0).mul_add(i as f64 / samples as f64, range.0);
            let vars_map = HashMap::from([(var.to_string(), t)]);
            let x = eval_expr(x_expr, &vars_map).ok()?;
            let y = eval_expr(y_expr, &vars_map).ok()?;
            let z = eval_expr(z_expr, &vars_map).ok()?;
            (x.is_finite() && y.is_finite() && z.is_finite()).then_some((x, y, z))
        })
        .collect();
    if points.is_empty() {
        return Err(format!(
            "Curve could not be evaluated to a finite point on [{}, {}]",
            range.0, range.1
        ));
    }
    let root = BitMapBackend::new(path, (conf.width, conf.height)).into_drawing_area();
    root.fill(&WHITE).map_err(|e| e.to_string())?;
    let mut chart = ChartBuilder::on(&root)
        .caption(&conf.caption, ("sans-serif", conf.caption_font_size).into_font())
        .build_cartesian_3d(
            axis_range(points.iter().map(|p| p.0)),
            axis_range(points.iter().map(|p| p.1)),
            axis_range(points.iter().map(|p| p.2)),
        )
        .map_err(|e| e.to_string())?;
    chart.with_projection(|mut pb| {
        pb.pitch = conf.pitch;
        pb.yaw = conf.yaw;
        pb.scale = conf.scale;
        pb.into_matrix()
    });
    chart.configure_axes().draw().map_err(|e| e.to_string())?;
    chart
        .draw_series(LineSeries::new(points, &conf.line_color))
        .map_err(|e| e.to_string())?;
    root.present().map_err(|e| e.to_string())?;
    Ok(())
}
/// Draws the 3D vector field `(comps.0, comps.1, comps.2)` inside the box given
/// by `ranges` and writes the bitmap to `path`.
///
/// The field is sampled on a cubic grid of about `cbrt(samples)` points per
/// axis; a 1×1×1 grid samples the centre of the box. At each grid point where
/// all three components evaluate, a segment of length 5% of the x-range is drawn
/// in the field's direction. Near-zero vectors (magnitude ≤ 1e-6) are skipped.
/// The 3D view uses `pitch`, `yaw` and `scale` from the config.
///
/// # Errors
///
/// Returns `Err` if plotters fails to build or write the chart.
pub fn plot_vector_field_3d(
    comps: (&Expr, &Expr, &Expr),
    vars: (&str, &str, &str),
    ranges: ((f64, f64), (f64, f64), (f64, f64)),
    path: &str,
    config: Option<PlotConfig>,
) -> Result<(), String> {
    let conf = config.unwrap_or_default();
    let root = BitMapBackend::new(path, (conf.width, conf.height)).into_drawing_area();
    root.fill(&WHITE).map_err(|e| e.to_string())?;
    let (x_range, y_range, z_range) = ranges;
    let mut chart = ChartBuilder::on(&root)
        .caption(&conf.caption, ("sans-serif", conf.caption_font_size).into_font())
        .build_cartesian_3d(
            x_range.0..x_range.1,
            y_range.0..y_range.1,
            z_range.0..z_range.1,
        )
        .map_err(|e| e.to_string())?;
    chart.with_projection(|mut pb| {
        pb.pitch = conf.pitch;
        pb.yaw = conf.yaw;
        pb.scale = conf.scale;
        pb.into_matrix()
    });
    chart.configure_axes().draw().map_err(|e| e.to_string())?;
    let (vx_expr, vy_expr, vz_expr) = comps;
    let (x_var, y_var, z_var) = vars;
    let steps = (conf.samples as f64).cbrt().round() as usize;
    // Grid coordinate `k` along `(lo, hi)`. With a single step, `steps - 1` is 0
    // and the division would produce NaN, so that case uses the midpoint.
    let grid = |(lo, hi): (f64, f64), k: usize| -> f64 {
        if steps > 1 {
            (hi - lo).mul_add(k as f64 / (steps - 1) as f64, lo)
        } else {
            0.5 * (lo + hi)
        }
    };
    let arrow_len = (x_range.1 - x_range.0) * 0.05;
    let mut arrows = Vec::new();
    for i in 0..steps {
        for j in 0..steps {
            for k in 0..steps {
                let (x, y, z) = (grid(x_range, i), grid(y_range, j), grid(z_range, k));
                let vars_map = HashMap::from([
                    (x_var.to_string(), x),
                    (y_var.to_string(), y),
                    (z_var.to_string(), z),
                ]);
                if let (Ok(vx), Ok(vy), Ok(vz)) = (
                    eval_expr(vx_expr, &vars_map),
                    eval_expr(vy_expr, &vars_map),
                    eval_expr(vz_expr, &vars_map),
                ) {
                    let magnitude = vz.mul_add(vz, vx.mul_add(vx, vy * vy)).sqrt();
                    if magnitude > 1e-6 {
                        let end_x = (vx / magnitude).mul_add(arrow_len, x);
                        let end_y = (vy / magnitude).mul_add(arrow_len, y);
                        let end_z = (vz / magnitude).mul_add(arrow_len, z);
                        arrows.push(PathElement::new(
                            vec![(x, y, z), (end_x, end_y, end_z)],
                            conf.line_color,
                        ));
                    }
                }
            }
        }
    }
    chart.draw_series(arrows).map_err(|e| e.to_string())?;
    root.present().map_err(|e| e.to_string())?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use proptest::prelude::*;

    use super::*;
    use crate::prelude::Expr;

    /// Environment that binds only the variable `x` to `value`.
    fn env_x(value: f64) -> HashMap<String, f64> {
        HashMap::from([("x".to_string(), value)])
    }

    #[test]
    fn test_eval_basic() {
        let expr = Expr::new_add(Expr::Variable("x".to_string()), Expr::Constant(2.0));
        let got = eval_expr(&expr, &env_x(3.0)).unwrap();
        assert_eq!(got, 5.0);
    }

    #[test]
    fn test_eval_trig() {
        let expr = Expr::new_sin(Expr::Variable("x".to_string()));
        let got = eval_expr(&expr, &env_x(std::f64::consts::PI / 2.0)).unwrap();
        assert!((got - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_eval_log_power() {
        // exp(ln(x)) round-trips for positive x.
        let x = Expr::Variable("x".to_string());
        let expr = Expr::new_exp(Expr::new_log(x.clone()));
        let got = eval_expr(&expr, &env_x(5.0)).unwrap();
        assert!((got - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_eval_error() {
        // Evaluating an unbound variable must fail rather than panic.
        let unbound = Expr::Variable("x".to_string());
        let empty: HashMap<String, f64> = HashMap::new();
        assert!(eval_expr(&unbound, &empty).is_err());
    }

    proptest! {
        #[test]
        fn prop_no_panic_eval(
            val in -100.0..100.0f64,
            depth in 1..4usize,
        ) {
            // Repeatedly wrap as (acc + 1) * 0.5.
            let expr = (0..depth).fold(Expr::Variable("x".to_string()), |acc, _| {
                Expr::new_mul(Expr::new_add(acc, Expr::Constant(1.0)), Expr::Constant(0.5))
            });
            let _ = eval_expr(&expr, &env_x(val));
        }
    }
}
/// Draws a 3D polyline through the points in `data` and writes the bitmap to
/// `path`.
///
/// Each entry is read as `[x, y, z, ..]`. Entries with fewer than three
/// coordinates, or with a NaN/infinite coordinate among the first three, are
/// ignored; extra coordinates are unused. A point's `z` is drawn on the chart's
/// vertical axis and its `y` on the depth axis. Axis ranges are fitted to the
/// data, and an axis with no spread gets a ±1 band. The 3D view uses `pitch`,
/// `yaw` and `scale` from the config.
///
/// # Errors
///
/// Returns `Err` if `data` contains no usable point, or if plotters fails to
/// build, draw or write the chart.
pub fn plot_3d_path_from_points(
    data: &[Vec<f64>],
    path: &str,
    config: Option<PlotConfig>,
) -> Result<(), String> {
    /// Range exactly spanning `values`, widened to ±1 when they are all equal.
    fn fitted(values: impl Iterator<Item = f64>) -> std::ops::Range<f64> {
        let (lo, hi) = values
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| (lo.min(v), hi.max(v)));
        if hi - lo < 1e-9 {
            (lo - 1.0)..(hi + 1.0)
        } else {
            lo..hi
        }
    }

    let conf = config.unwrap_or_default();
    let points: Vec<(f64, f64, f64)> = data
        .iter()
        .filter(|p| p.len() >= 3)
        .map(|p| (p[0], p[1], p[2]))
        .filter(|&(x, y, z)| x.is_finite() && y.is_finite() && z.is_finite())
        .collect();
    if points.is_empty() {
        return Err("No points with at least three finite coordinates provided".to_string());
    }
    let root = BitMapBackend::new(path, (conf.width, conf.height)).into_drawing_area();
    root.fill(&WHITE).map_err(|e| e.to_string())?;
    // Chart axes are (x, vertical, depth) = (x, z, y) of the input points.
    let mut chart = ChartBuilder::on(&root)
        .caption(&conf.caption, ("sans-serif", conf.caption_font_size).into_font())
        .build_cartesian_3d(
            fitted(points.iter().map(|p| p.0)),
            fitted(points.iter().map(|p| p.2)),
            fitted(points.iter().map(|p| p.1)),
        )
        .map_err(|e| e.to_string())?;
    chart.with_projection(|mut pb| {
        pb.pitch = conf.pitch;
        pb.yaw = conf.yaw;
        pb.scale = conf.scale;
        pb.into_matrix()
    });
    chart.configure_axes().draw().map_err(|e| e.to_string())?;
    chart
        .draw_series(LineSeries::new(
            points.iter().map(|&(x, y, z)| (x, z, y)),
            &conf.line_color,
        ))
        .map_err(|e| e.to_string())?;
    root.present().map_err(|e| e.to_string())?;
    Ok(())
}
/// Renders `data` as a heat map and writes the bitmap to `path`.
///
/// Each element `data[[row, col]]` becomes a unit cell at x = `col`,
/// y = `row`, with row 0 at the bottom. Cell colour runs in hue from red (0°)
/// at the smallest finite value to blue (240°) at the largest. A constant
/// array is drawn entirely at the middle hue (120°, green).
///
/// # Errors
///
/// Returns `Err` if:
/// - `data` has no rows or no columns,
/// - `data` contains no finite value, or
/// - plotters fails to build, draw or write the chart.
pub fn plot_heatmap_2d(
    data: &Array2<f64>,
    path: &str,
    config: Option<PlotConfig>,
) -> Result<(), String> {
    let conf = config.unwrap_or_default();
    let (height, width) = data.dim();
    if height == 0 || width == 0 {
        return Err("Cannot plot a heat map from an empty array".to_string());
    }
    let (min_val, max_val) = data
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| (lo.min(v), hi.max(v)));
    if min_val > max_val {
        return Err("Array contains no finite values".to_string());
    }
    let span = max_val - min_val;
    let root = BitMapBackend::new(path, (conf.width, conf.height)).into_drawing_area();
    root.fill(&WHITE).map_err(|e| e.to_string())?;
    let mut chart = ChartBuilder::on(&root)
        .caption(&conf.caption, ("sans-serif", conf.caption_font_size).into_font())
        .build_cartesian_2d(0..width as u32, 0..height as u32)
        .map_err(|e| e.to_string())?;
    chart.configure_mesh().draw().map_err(|e| e.to_string())?;
    chart
        .draw_series((0..width).flat_map(move |x| {
            (0..height).map(move |y| {
                let normalized = if span > 0.0 {
                    (data[[y, x]] - min_val) / span
                } else {
                    0.5
                };
                // plotters' `HSLColor` takes hue as a fraction of a full turn in
                // [0, 1] and clamps anything outside, so 240° is 240/360.
                let color = HSLColor(normalized * (240.0 / 360.0), 1.0, 0.5);
                Rectangle::new(
                    [(x as u32, y as u32), (x as u32 + 1, y as u32 + 1)],
                    color.filled(),
                )
            })
        }))
        .map_err(|e| e.to_string())?;
    root.present().map_err(|e| e.to_string())?;
    Ok(())
}