use crate::error::{InferustError, Result};
use std::fmt::Write as FmtWrite;
/// Palette of eight hex colors handed out to series in the order they are added.
const COLORS: &[&str] = &[
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
];

/// Returns the palette entry for the `idx`-th series, wrapping back to the first
/// color once all eight have been used.
fn color(idx: usize) -> &'static str {
    let slot = idx % COLORS.len();
    COLORS[slot]
}
/// One drawable layer of a `Plot`. Layers are stored in insertion order and drawn in
/// that order, so later layers paint over earlier ones.
#[derive(Debug, Clone)]
enum Series {
    /// Connected polyline through the `(x[i], y[i])` pairs; shown in the legend.
    Line {
        x: Vec<f64>,
        y: Vec<f64>,
        /// Legend text.
        label: String,
        /// SVG color string (the builders use hex palette entries).
        color: String,
    },
    /// Separate circle markers at each `(x[i], y[i])`; shown in the legend.
    Scatter {
        x: Vec<f64>,
        y: Vec<f64>,
        label: String,
        color: String,
    },
    /// Vertical bars centred on each `x[i]`, drawn between a zero baseline and
    /// `heights[i]`; shown in the legend.
    Bar {
        x: Vec<f64>,
        heights: Vec<f64>,
        label: String,
        color: String,
    },
    /// Translucent filled region between `lo[i]` and `hi[i]` at each `x[i]`.
    /// It has no label, so it never appears in the legend; the ASCII renderer skips it.
    Band {
        x: Vec<f64>,
        lo: Vec<f64>,
        hi: Vec<f64>,
        color: String,
    },
    /// Step curve: each y value is held horizontally until the next x, then the curve
    /// jumps vertically to the next y. Shown in the legend.
    Step {
        x: Vec<f64>,
        y: Vec<f64>,
        label: String,
        color: String,
    },
    /// Horizontal reference line at `y` spanning the whole x-range. It is not shown in
    /// the legend.
    HLine {
        y: f64,
        color: String,
        /// When true the SVG stroke is dashed.
        dash: bool,
    },
}
/// Builder for simple 2-D charts that render either to an SVG string
/// ([`Plot::to_svg`]) or as text on stdout ([`Plot::print_ascii`]).
///
/// Start from [`Plot::new`], which sets a 600 x 350 canvas. The derived `Default`
/// leaves `width` and `height` at 0.0, which yields an unusable SVG canvas.
#[derive(Debug, Clone, Default)]
pub struct Plot {
    /// Title drawn above the plot area; an empty string means no title.
    title: String,
    /// x-axis label; an empty string means no label.
    xlabel: String,
    /// y-axis label; an empty string means no label.
    ylabel: String,
    /// Layers in drawing order; each builder call appends one.
    series: Vec<Series>,
    /// Canvas width in SVG user units (used as the viewBox width).
    width: f64,
    /// Canvas height in SVG user units (used as the viewBox height).
    height: f64,
}
impl Plot {
    /// Creates an empty plot with a 600 x 350 canvas and no title, labels, or series.
    pub fn new() -> Self {
        Self {
            width: 600.0,
            height: 350.0,
            ..Default::default()
        }
    }

    /// Sets the title drawn above the plot area (nothing is drawn when empty).
    pub fn title(mut self, t: impl Into<String>) -> Self {
        self.title = t.into();
        self
    }

    /// Sets the x-axis label drawn below the plot area (nothing is drawn when empty).
    pub fn xlabel(mut self, l: impl Into<String>) -> Self {
        self.xlabel = l.into();
        self
    }

    /// Sets the y-axis label, drawn rotated along the left edge (nothing when empty).
    pub fn ylabel(mut self, l: impl Into<String>) -> Self {
        self.ylabel = l.into();
        self
    }

    /// Sets the SVG canvas width in user units.
    pub fn width(mut self, w: f64) -> Self {
        self.width = w;
        self
    }

    /// Sets the SVG canvas height in user units.
    pub fn height(mut self, h: f64) -> Self {
        self.height = h;
        self
    }

    /// Palette color for the next series. It is indexed by the number of layers already
    /// added, so reference lines (which use a fixed grey) still advance the position.
    fn next_color(&self) -> String {
        color(self.series.len()).to_string()
    }

    /// Adds a connected line through the points `(x[i], y[i])`.
    ///
    /// Points are paired positionally; surplus values in the longer slice are ignored.
    pub fn line(mut self, x: &[f64], y: &[f64], label: impl Into<String>) -> Self {
        let color = self.next_color();
        self.series.push(Series::Line {
            x: x.to_vec(),
            y: y.to_vec(),
            label: label.into(),
            color,
        });
        self
    }

    /// Adds circle markers at the points `(x[i], y[i])`.
    pub fn scatter(mut self, x: &[f64], y: &[f64], label: impl Into<String>) -> Self {
        let color = self.next_color();
        self.series.push(Series::Scatter {
            x: x.to_vec(),
            y: y.to_vec(),
            label: label.into(),
            color,
        });
        self
    }

    /// Adds vertical bars centred on each `x[i]` with height `heights[i]`, measured from 0.
    pub fn bar(mut self, x: &[f64], heights: &[f64], label: impl Into<String>) -> Self {
        let color = self.next_color();
        self.series.push(Series::Bar {
            x: x.to_vec(),
            heights: heights.to_vec(),
            label: label.into(),
            color,
        });
        self
    }

    /// Adds a step curve that holds each `y[i]` until `x[i + 1]`, then jumps.
    pub fn step(mut self, x: &[f64], y: &[f64], label: impl Into<String>) -> Self {
        let color = self.next_color();
        self.series.push(Series::Step {
            x: x.to_vec(),
            y: y.to_vec(),
            label: label.into(),
            color,
        });
        self
    }

    /// Adds a translucent band between `lo[i]` and `hi[i]` at each `x[i]`.
    ///
    /// Bands have no label and are omitted from the legend and from ASCII output.
    pub fn band(mut self, x: &[f64], lo: &[f64], hi: &[f64]) -> Self {
        let color = self.next_color();
        self.series.push(Series::Band {
            x: x.to_vec(),
            lo: lo.to_vec(),
            hi: hi.to_vec(),
            color,
        });
        self
    }

    /// Adds a grey horizontal reference line at `y` across the full x-range.
    ///
    /// If `dashed` is true, the line is drawn dashed.
    pub fn hline(mut self, y: f64, dashed: bool) -> Self {
        self.series.push(Series::HLine {
            y,
            color: "#aaaaaa".into(),
            dash: dashed,
        });
        self
    }

    /// Builds an autocorrelation bar chart.
    ///
    /// Dashed lines mark `+conf_bound` and `-conf_bound`, and a solid line marks zero.
    pub fn acf(lags: &[usize], values: &[f64], conf_bound: f64) -> Self {
        let x: Vec<f64> = lags.iter().map(|&l| l as f64).collect();
        Plot::new()
            .title("ACF")
            .xlabel("lag")
            .ylabel("autocorrelation")
            .bar(&x, values, "acf")
            .hline(conf_bound, true)
            .hline(-conf_bound, true)
            .hline(0.0, false)
    }

    /// Builds a Kaplan-Meier plot: a step curve of `survival` over a band spanning
    /// `ci_lower`..`ci_upper`, one point per entry of `curve`.
    pub fn survival(curve: &[crate::survival::KmStep]) -> Self {
        let event_times: Vec<f64> = curve.iter().map(|s| s.time).collect();
        let surv: Vec<f64> = curve.iter().map(|s| s.survival).collect();
        let lo: Vec<f64> = curve.iter().map(|s| s.ci_lower).collect();
        let hi: Vec<f64> = curve.iter().map(|s| s.ci_upper).collect();
        Plot::new()
            .title("Kaplan-Meier Survival Curve")
            .xlabel("time")
            .ylabel("S(t)")
            .band(&event_times, &lo, &hi)
            .step(&event_times, &surv, "survival")
    }

    /// Builds a residuals-vs-fitted scatter plot with a dashed line at zero.
    pub fn residuals(fitted: &[f64], residuals: &[f64]) -> Self {
        Plot::new()
            .title("Residuals vs Fitted")
            .xlabel("fitted values")
            .ylabel("residuals")
            .scatter(fitted, residuals, "residuals")
            .hline(0.0, true)
    }

    /// Renders the plot as a standalone SVG document.
    ///
    /// The plot area is framed and has 6 horizontal and 7 vertical grid lines with
    /// tick labels. Every series is clipped to the plot area. Labelled series (line,
    /// scatter, bar, step) are listed in a legend in the top-right corner.
    ///
    /// NOTE(review): the clip path id `plot-area` is fixed. Inlining several of these
    /// SVGs into one HTML page will make the ids collide.
    pub fn to_svg(&self) -> String {
        let (w, h) = (self.width, self.height);
        // Margins (top, right, bottom, left) grow to make room for the optional text.
        let mt = if self.title.is_empty() { 15.0 } else { 35.0 };
        let mr = 20.0;
        let mb = if self.xlabel.is_empty() { 40.0 } else { 55.0 };
        let ml = if self.ylabel.is_empty() { 45.0 } else { 60.0 };
        let pw = w - ml - mr;
        let ph = h - mt - mb;

        // Scan the data once; render_series reuses these bounds instead of rescanning.
        let bounds = self.data_bounds();
        let (xmin, xmax, ymin, ymax) = bounds;
        let xrange = (xmax - xmin).max(f64::EPSILON);
        let yrange = (ymax - ymin).max(f64::EPSILON);
        let sx = |xv: f64| ml + (xv - xmin) / xrange * pw;
        // SVG y grows downward, so data y is flipped.
        let sy = |yv: f64| mt + ph - (yv - ymin) / yrange * ph;

        let mut svg = String::new();
        let _ = write!(
            svg,
            r#"<svg viewBox="0 0 {w} {h}" xmlns="http://www.w3.org/2000/svg" font-family="sans-serif">"#
        );
        let _ = write!(svg, r#"<rect width="{w}" height="{h}" fill="white"/>"#);

        // Horizontal grid lines with y tick labels.
        let n_yticks = 5usize;
        for i in 0..=n_yticks {
            let yv = ymin + yrange * i as f64 / n_yticks as f64;
            let yp = sy(yv);
            let _ = write!(
                svg,
                r##"<line x1="{ml:.1}" y1="{yp:.1}" x2="{:.1}" y2="{yp:.1}" stroke="#e0e0e0" stroke-width="0.5"/>"##,
                ml + pw
            );
            let _ = write!(
                svg,
                r##"<text x="{:.1}" y="{:.1}" text-anchor="end" font-size="11" fill="#666">{:.2}</text>"##,
                ml - 4.0,
                yp + 4.0,
                yv
            );
        }

        // Vertical grid lines with x tick labels.
        let n_xticks = 6usize;
        for i in 0..=n_xticks {
            let xv = xmin + xrange * i as f64 / n_xticks as f64;
            let xp = sx(xv);
            let _ = write!(
                svg,
                r##"<line x1="{xp:.1}" y1="{mt:.1}" x2="{xp:.1}" y2="{:.1}" stroke="#e0e0e0" stroke-width="0.5"/>"##,
                mt + ph
            );
            let _ = write!(
                svg,
                r##"<text x="{xp:.1}" y="{:.1}" text-anchor="middle" font-size="11" fill="#666">{:.2}</text>"##,
                mt + ph + 14.0,
                xv
            );
        }

        // Frame, plus the clip region that every series references.
        let _ = write!(
            svg,
            r##"<rect x="{ml:.1}" y="{mt:.1}" width="{pw:.1}" height="{ph:.1}" fill="none" stroke="#cccccc" stroke-width="0.8"/>"##
        );
        let _ = write!(
            svg,
            r#"<clipPath id="plot-area"><rect x="{ml:.1}" y="{mt:.1}" width="{pw:.1}" height="{ph:.1}"/></clipPath>"#
        );

        for s in &self.series {
            Self::render_series(&mut svg, s, &sx, &sy, pw, bounds);
        }

        if !self.xlabel.is_empty() {
            let _ = write!(
                svg,
                r##"<text x="{:.1}" y="{:.1}" text-anchor="middle" font-size="12" fill="#444">{}</text>"##,
                ml + pw / 2.0,
                h - 5.0,
                escape_xml(&self.xlabel)
            );
        }
        if !self.ylabel.is_empty() {
            // Rotated -90 degrees, so the x coordinate runs along the (negated) vertical axis.
            let _ = write!(
                svg,
                r##"<text transform="rotate(-90)" x="{:.1}" y="{:.1}" text-anchor="middle" font-size="12" fill="#444">{}</text>"##,
                -(mt + ph / 2.0),
                13.0,
                escape_xml(&self.ylabel)
            );
        }
        if !self.title.is_empty() {
            let _ = write!(
                svg,
                r##"<text x="{:.1}" y="{:.1}" text-anchor="middle" font-size="14" font-weight="500" fill="#222">{}</text>"##,
                w / 2.0,
                mt - 8.0,
                escape_xml(&self.title)
            );
        }

        // Legend: one swatch + label per labelled series, stacked down from the top-right.
        let legend: Vec<(&str, &str)> = self
            .series
            .iter()
            .filter_map(|s| match s {
                Series::Line { label, color, .. }
                | Series::Scatter { label, color, .. }
                | Series::Bar { label, color, .. }
                | Series::Step { label, color, .. } => Some((label.as_str(), color.as_str())),
                Series::Band { .. } | Series::HLine { .. } => None,
            })
            .collect();
        let lx = ml + pw - 10.0;
        let mut ly = mt + 10.0;
        for (label, col) in &legend {
            let _ = write!(
                svg,
                r#"<rect x="{:.1}" y="{:.1}" width="12" height="3" fill="{col}" rx="1"/>"#,
                lx - 18.0,
                ly + 4.0
            );
            let _ = write!(
                svg,
                r##"<text x="{:.1}" y="{:.1}" text-anchor="end" font-size="11" fill="#444">{}</text>"##,
                lx - 22.0,
                ly + 10.0,
                escape_xml(label)
            );
            ly += 16.0;
        }

        svg.push_str("</svg>");
        svg
    }

    /// Writes the SVG rendering to `path`, overwriting any existing file.
    ///
    /// # Errors
    /// Returns [`InferustError::InvalidInput`] carrying the I/O error text if the write fails.
    pub fn save(&self, path: &str) -> Result<()> {
        std::fs::write(path, self.to_svg())
            .map_err(|e| InferustError::InvalidInput(format!("failed to write {path}: {e}")))
    }

    /// Prints a 70 x 20 character rendering of the plot to stdout.
    ///
    /// The symbols are:
    /// - `*` for line and step points
    /// - `o` for scatter points
    /// - `|` for bar bodies, capped with `#`
    /// - `.` for reference lines (only where the cell is empty)
    /// - `-` for the y = 0 axis when it is in range
    ///
    /// Bands are not drawn. Every fifth row is labelled with its y value.
    pub fn print_ascii(&self) {
        const COLS: usize = 70;
        const ROWS: usize = 20;
        const X_TICKS: usize = 7;
        // Width of the right-aligned y labels; every row is prefixed by label + " |".
        const LABEL_W: usize = 7;
        const GUTTER: usize = LABEL_W + 2;

        let (xmin, xmax, ymin, ymax) = self.data_bounds();
        let xrange = (xmax - xmin).max(f64::EPSILON);
        let yrange = (ymax - ymin).max(f64::EPSILON);
        let mut grid = vec![vec![' '; COLS]; ROWS];

        // Data -> cell mapping; out-of-range values are clamped onto the nearest edge.
        let col_of = |x: f64| {
            ((x - xmin) / xrange * (COLS - 1) as f64)
                .round()
                .clamp(0.0, (COLS - 1) as f64) as usize
        };
        let row_of = |y: f64| {
            let from_bottom = ((y - ymin) / yrange * (ROWS - 1) as f64)
                .round()
                .clamp(0.0, (ROWS - 1) as f64) as usize;
            ROWS - 1 - from_bottom
        };

        if ymin <= 0.0 && ymax >= 0.0 {
            grid[row_of(0.0)].fill('-');
        }

        for s in &self.series {
            match s {
                Series::Line { x, y, .. } | Series::Step { x, y, .. } => {
                    for (&xv, &yv) in x.iter().zip(y) {
                        grid[row_of(yv)][col_of(xv)] = '*';
                    }
                }
                Series::Scatter { x, y, .. } => {
                    for (&xv, &yv) in x.iter().zip(y) {
                        grid[row_of(yv)][col_of(xv)] = 'o';
                    }
                }
                Series::Bar { x, heights, .. } => {
                    // Bars grow from y = 0, kept inside the visible range.
                    let baseline = row_of(0.0_f64.clamp(ymin, ymax));
                    for (&xv, &hv) in x.iter().zip(heights) {
                        let top = row_of(hv);
                        let c = col_of(xv);
                        for row in &mut grid[top.min(baseline)..=top.max(baseline)] {
                            row[c] = '|';
                        }
                        grid[top][c] = '#';
                    }
                }
                Series::HLine { y, .. } => {
                    if (ymin..=ymax).contains(y) {
                        // Dotted only where nothing else has been drawn.
                        for cell in grid[row_of(*y)].iter_mut().filter(|c| **c == ' ') {
                            *cell = '.';
                        }
                    }
                }
                // Shaded regions have no sensible character rendering.
                Series::Band { .. } => {}
            }
        }

        if !self.title.is_empty() {
            println!("{:^width$}", self.title, width = COLS + GUTTER);
        }
        for (i, row) in grid.iter().enumerate() {
            let cells: String = row.iter().collect();
            if i % 5 == 0 {
                let yv = ymax - (i as f64 / (ROWS - 1) as f64) * yrange;
                println!("{:>w$.2} |{}", yv, cells, w = LABEL_W);
            } else {
                // Blank gutter of the same width keeps the grid columns aligned.
                println!("{:>w$} |{}", "", cells, w = LABEL_W);
            }
        }
        println!("{:>w$} +{}", "", "-".repeat(COLS), w = LABEL_W);

        // X tick labels, each left-aligned in an equal slot; the first starts under column 0.
        let slot = COLS / X_TICKS;
        let mut axis = " ".repeat(GUTTER);
        for i in 0..=X_TICKS {
            let xv = xmin + xrange * i as f64 / X_TICKS as f64;
            let label = format!("{xv:.1}");
            if i < X_TICKS {
                let _ = write!(axis, "{label:<slot$}");
            } else {
                axis.push_str(&label);
            }
        }
        println!("{axis}");
        if !self.xlabel.is_empty() {
            println!("{:^width$}", self.xlabel, width = COLS + GUTTER);
        }
    }

    /// Returns `(xmin, xmax, ymin, ymax)` covering every series.
    ///
    /// - Only finite values are considered.
    /// - An empty plot falls back to `[0, 1]` on both axes.
    /// - A zero-width x-range is widened to 1.
    /// - A zero-height y-range is widened by 0.5 on each side.
    /// - The y-range is then padded by 5% at both ends.
    fn data_bounds(&self) -> (f64, f64, f64, f64) {
        /// Widens the running `(min, max)` pair to include every finite value in `vals`.
        fn widen(extent: &mut (f64, f64), vals: &[f64]) {
            for &v in vals.iter().filter(|v| v.is_finite()) {
                extent.0 = extent.0.min(v);
                extent.1 = extent.1.max(v);
            }
        }

        let mut xe = (f64::INFINITY, f64::NEG_INFINITY);
        let mut ye = (f64::INFINITY, f64::NEG_INFINITY);
        for s in &self.series {
            match s {
                Series::Line { x, y, .. }
                | Series::Scatter { x, y, .. }
                | Series::Step { x, y, .. } => {
                    widen(&mut xe, x);
                    widen(&mut ye, y);
                }
                Series::Bar { x, heights, .. } => {
                    widen(&mut xe, x);
                    widen(&mut ye, heights);
                }
                Series::Band { x, lo, hi, .. } => {
                    // Both edges widen both bounds, so an empty or inverted edge cannot
                    // leave min > max.
                    widen(&mut xe, x);
                    widen(&mut ye, lo);
                    widen(&mut ye, hi);
                }
                Series::HLine { y, .. } => widen(&mut ye, std::slice::from_ref(y)),
            }
        }

        // Because only finite values are accepted, min <= max exactly when data was seen.
        let (xmin, mut xmax) = if xe.0 <= xe.1 { xe } else { (0.0, 1.0) };
        let (mut ymin, mut ymax) = if ye.0 <= ye.1 { ye } else { (0.0, 1.0) };
        if xmax - xmin < f64::EPSILON {
            xmax = xmin + 1.0;
        }
        if ymax - ymin < f64::EPSILON {
            ymin -= 0.5;
            ymax += 0.5;
        }
        let ypad = (ymax - ymin) * 0.05;
        (xmin, xmax, ymin - ypad, ymax + ypad)
    }

    /// Appends the SVG element(s) for one series to `svg`.
    ///
    /// - `sx`/`sy` map data coordinates to pixels.
    /// - `pw` is the plot-area width, used to size bars.
    /// - `bounds` is the `(xmin, xmax, ymin, ymax)` tuple from `data_bounds`.
    fn render_series(
        svg: &mut String,
        s: &Series,
        sx: &impl Fn(f64) -> f64,
        sy: &impl Fn(f64) -> f64,
        pw: f64,
        bounds: (f64, f64, f64, f64),
    ) {
        let (xmin, xmax, ymin, ymax) = bounds;
        match s {
            Series::Line { x, y, color, .. } => {
                if x.is_empty() || y.is_empty() {
                    return;
                }
                let _ = write!(
                    svg,
                    r#"<polyline clip-path="url(#plot-area)" fill="none" stroke="{color}" stroke-width="1.8" points=""#
                );
                for (&xv, &yv) in x.iter().zip(y) {
                    let _ = write!(svg, "{:.1},{:.1} ", sx(xv), sy(yv));
                }
                svg.push_str(r#""/>"#);
            }
            Series::Step { x, y, color, .. } => {
                // Zipping (rather than indexing y) tolerates y being shorter than x.
                let mut points = x.iter().zip(y);
                let (x0, y0) = match points.next() {
                    Some((&x0, &y0)) => (x0, y0),
                    None => return,
                };
                let _ = write!(
                    svg,
                    r#"<path clip-path="url(#plot-area)" fill="none" stroke="{color}" stroke-width="1.8" d=""#
                );
                let _ = write!(svg, "M {:.1},{:.1} ", sx(x0), sy(y0));
                // Horizontal to the next x at the current y, then vertical to the next y.
                for (&xv, &yv) in points {
                    let _ = write!(svg, "H {:.1} V {:.1} ", sx(xv), sy(yv));
                }
                svg.push_str(r#""/>"#);
            }
            Series::Scatter { x, y, color, .. } => {
                for (&xv, &yv) in x.iter().zip(y) {
                    let _ = write!(
                        svg,
                        r#"<circle clip-path="url(#plot-area)" cx="{:.1}" cy="{:.1}" r="3" fill="{color}" fill-opacity="0.7"/>"#,
                        sx(xv),
                        sy(yv)
                    );
                }
            }
            Series::Bar {
                x, heights, color, ..
            } => {
                let n = x.len();
                // Bars take 70% of an equal share of the width; a lone bar gets a fixed width.
                let bar_w = if n > 1 {
                    (pw / n as f64 * 0.7).max(2.0)
                } else {
                    20.0
                };
                // Baseline at y = 0, kept inside the y-range (same rule as print_ascii).
                let baseline_y = sy(0.0_f64.clamp(ymin, ymax));
                for (&xv, &hv) in x.iter().zip(heights) {
                    let cx = sx(xv);
                    let top_y = sy(hv);
                    let bh = (baseline_y - top_y).abs().max(1.0);
                    let rect_y = top_y.min(baseline_y);
                    let _ = write!(
                        svg,
                        r#"<rect clip-path="url(#plot-area)" x="{:.1}" y="{:.1}" width="{:.1}" height="{:.1}" fill="{color}" fill-opacity="0.8"/>"#,
                        cx - bar_w / 2.0,
                        rect_y,
                        bar_w,
                        bh
                    );
                }
            }
            Series::Band {
                x, lo, hi, color, ..
            } => {
                if x.is_empty() {
                    return;
                }
                let _ = write!(
                    svg,
                    r#"<polygon clip-path="url(#plot-area)" fill="{color}" fill-opacity="0.15" points=""#
                );
                // Upper edge left-to-right...
                for (&xv, &hv) in x.iter().zip(hi) {
                    let _ = write!(svg, "{:.1},{:.1} ", sx(xv), sy(hv));
                }
                // ...then the lower edge right-to-left. Reversing the zipped iterator keeps
                // x and lo paired even when their lengths differ.
                for (&xv, &lv) in x.iter().zip(lo).rev() {
                    let _ = write!(svg, "{:.1},{:.1} ", sx(xv), sy(lv));
                }
                svg.push_str(r#""/>"#);
            }
            Series::HLine { y, color, dash } => {
                let yp = sy(*y);
                let dash_attr = if *dash {
                    r#" stroke-dasharray="4,3""#
                } else {
                    ""
                };
                let _ = write!(
                    svg,
                    r#"<line clip-path="url(#plot-area)" x1="{:.1}" y1="{yp:.1}" x2="{:.1}" y2="{yp:.1}" stroke="{color}" stroke-width="1"{dash_attr}/>"#,
                    sx(xmin),
                    sx(xmax)
                );
            }
        }
    }
}
/// Escapes the characters that are special in XML text and in double-quoted attribute
/// values (`&`, `<`, `>`, `"`), so user-supplied titles and labels cannot break the
/// generated SVG markup. All other characters, including non-ASCII, pass through
/// unchanged.
fn escape_xml(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        // Map each special character to its predefined XML entity name.
        let entity = match ch {
            '&' => "amp",
            '<' => "lt",
            '>' => "gt",
            '"' => "quot",
            _ => {
                out.push(ch);
                continue;
            }
        };
        out.push('&');
        out.push_str(entity);
        out.push(';');
    }
    out
}
#[cfg(test)]
mod tests {
    use super::Plot;

    #[test]
    fn svg_contains_polyline_for_line_series() {
        let x = vec![0.0, 1.0, 2.0, 3.0];
        let y = vec![0.0, 1.0, 0.5, 2.0];
        let svg = Plot::new().line(&x, &y, "test").to_svg();
        assert!(svg.contains("polyline"), "expected polyline in SVG: {svg}");
    }

    #[test]
    fn svg_contains_title() {
        let svg = Plot::new()
            .title("My Plot")
            .line(&[1.0], &[1.0], "s")
            .to_svg();
        assert!(svg.contains("My Plot"), "title missing from SVG");
    }

    #[test]
    fn svg_contains_circle_for_scatter() {
        let svg = Plot::new()
            .scatter(&[1.0, 2.0], &[3.0, 4.0], "pts")
            .to_svg();
        assert_eq!(svg.matches("<circle").count(), 2, "expected one circle per point");
    }

    #[test]
    fn svg_contains_rect_for_bar() {
        let svg = Plot::new()
            .bar(&[1.0, 2.0, 3.0], &[0.5, -0.3, 0.8], "acf")
            .to_svg();
        // The background, frame, clip region and legend are <rect>s too, so count only
        // the clipped bar rects.
        assert_eq!(
            svg.matches(r#"<rect clip-path="url(#plot-area)""#).count(),
            3,
            "expected one clipped rect per bar"
        );
    }

    #[test]
    fn acf_convenience_constructor() {
        let lags: Vec<usize> = (0..8).collect();
        let vals = vec![1.0, 0.7, 0.5, 0.3, 0.1, 0.0, -0.1, -0.2];
        let svg = Plot::acf(&lags, &vals, 0.31).to_svg();
        assert!(svg.contains("ACF"));
        // The +/- confidence bounds are dashed; the zero line is solid.
        assert_eq!(svg.matches("stroke-dasharray").count(), 2);
    }

    #[test]
    fn empty_plot_renders() {
        let p = Plot::new();
        let svg = p.to_svg();
        assert!(svg.starts_with("<svg") && svg.ends_with("</svg>"));
        p.print_ascii();
    }

    #[test]
    fn save_writes_file() {
        // Use the platform temp dir (not a hard-coded /tmp) and a per-process name.
        let path = std::env::temp_dir()
            .join(format!("inferust_test_plot_{}.svg", std::process::id()));
        let path_str = path.to_str().expect("temp dir path is valid UTF-8");
        Plot::new()
            .line(&[0.0, 1.0], &[0.0, 1.0], "l")
            .save(path_str)
            .unwrap();
        let contents = std::fs::read_to_string(&path).unwrap();
        std::fs::remove_file(&path).ok();
        assert!(contents.starts_with("<svg") && contents.ends_with("</svg>"));
    }

    #[test]
    fn print_ascii_does_not_panic() {
        // Called directly: a panic here must fail the test, so it is not caught.
        let x: Vec<f64> = (0..15).map(|i| i as f64).collect();
        let y: Vec<f64> = x.iter().map(|xi| xi.sin()).collect();
        Plot::new().line(&x, &y, "sin").print_ascii();
    }

    #[test]
    fn print_ascii_handles_all_negative_bars() {
        Plot::new()
            .bar(&[1.0, 2.0, 3.0], &[-1.0, -2.0, -0.5], "neg")
            .print_ascii();
    }
}