use eframe::egui;
use egui_plot::{Legend, Line, Plot, PlotPoint, PlotPoints, Text};
/// Display options for [`PlotObserver::show`], assembled with chained builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShowConfig {
    /// Window title; `show` falls back to an empty string when this is `None`.
    title: Option<String>,
    /// Whether a legend is added to the plot.
    legend: bool,
    /// Whether traces are drawn as `log10(y)` (non-positive points are dropped).
    log_y: bool,
}
impl ShowConfig {
#[must_use]
pub fn new() -> Self {
Self {
title: None,
legend: false,
log_y: false,
}
}
#[must_use]
pub fn title(mut self, title: impl Into<String>) -> Self {
self.title = Some(title.into());
self
}
#[must_use]
pub fn legend(mut self) -> Self {
self.legend = true;
self
}
#[must_use]
pub fn log_y(mut self) -> Self {
self.log_y = true;
self
}
}
impl Default for ShowConfig {
fn default() -> Self {
Self::new()
}
}
/// Collects up to `N` named x/y traces plus free-standing text labels for later display.
#[derive(Debug, Clone)]
pub struct PlotObserver<const N: usize> {
    /// Display name of each trace; index `i` names `data[i]`.
    names: [String; N],
    /// Recorded `[x, y]` points of each trace, in insertion order.
    data: [Vec<[f64; 2]>; N],
    /// Text annotations as `(x, y, text)`.
    labels: Vec<(f64, f64, String)>,
    /// Font size used when drawing `labels`.
    label_size: f32,
}
impl<const N: usize> PlotObserver<N> {
    /// Creates an observer with one empty trace per entry in `names`.
    ///
    /// The position of a name in `names` is the index its values take in the `traces`
    /// array passed to [`record`](Self::record). Labels start empty and are drawn at
    /// font size 14.0 until [`label_size`](Self::label_size) changes it.
    #[must_use]
    pub fn new(names: [&str; N]) -> Self {
        Self {
            names: names.map(str::to_owned),
            data: std::array::from_fn(|_| Vec::new()),
            labels: Vec::new(),
            label_size: 14.0,
        }
    }

    /// Records one optional sample per trace at horizontal position `x`.
    ///
    /// `traces[i] == Some(y)` appends the point `[x, y]` to trace `i`; a `None` entry leaves
    /// that trace untouched, so different traces may be sampled at different `x` values.
    pub fn record(&mut self, x: f64, traces: [Option<f64>; N]) {
        for (points, y) in self.data.iter_mut().zip(traces) {
            if let Some(y) = y {
                points.push([x, y]);
            }
        }
    }

    /// Adds a text annotation at `(x, y)`, drawn with the current label font size.
    pub fn label(&mut self, x: f64, y: f64, text: impl Into<String>) {
        self.labels.push((x, y, text.into()));
    }

    /// Sets the font size used for text labels and returns `self` for chaining.
    pub fn label_size(&mut self, size: f32) -> &mut Self {
        self.label_size = size;
        self
    }

    /// Consumes the observer and displays the recorded traces and labels in a native eframe
    /// window configured by `config`.
    ///
    /// The window title is `config.title`, or an empty string when none was set.
    ///
    /// # Errors
    ///
    /// Returns the [`eframe::Error`] reported by [`eframe::run_native`], if any.
    pub fn show(self, config: ShowConfig) -> Result<(), eframe::Error> {
        let options = eframe::NativeOptions::default();
        let title = config.title.unwrap_or_default();
        // Pair each trace with its display name so the app owns self-describing point lists.
        let traces: Vec<(String, Vec<[f64; 2]>)> = self.names.into_iter().zip(self.data).collect();
        eframe::run_native(
            &title,
            options,
            Box::new(move |_cc| {
                Ok(Box::new(PlotApp {
                    traces,
                    labels: self.labels,
                    label_size: self.label_size,
                    legend: config.legend,
                    log_y: config.log_y,
                    plot_rect: None,
                }))
            }),
        )
    }
}
/// Axis gutter under the pointer: a strip beside the plot frame where the scroll wheel
/// zooms along a single axis.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Gutter {
    /// Left of the plot frame, within its vertical extent; scrolling zooms the y axis.
    Y,
    /// Below the plot frame, within its horizontal extent; scrolling zooms the x axis.
    X,
}
/// eframe application state that renders the data handed over by [`PlotObserver::show`].
struct PlotApp {
    /// Screen rectangle of the plot frame from the previous frame; `None` before the first
    /// frame has been drawn. Used to hit-test the axis gutters.
    plot_rect: Option<egui::Rect>,
    /// Whether a legend is added to the plot.
    legend: bool,
    /// Whether traces are drawn as `log10(y)`.
    log_y: bool,
    /// Font size for `labels`.
    label_size: f32,
    /// Named point lists, one per trace.
    traces: Vec<(String, Vec<[f64; 2]>)>,
    /// Text annotations as `(x, y, text)`.
    labels: Vec<(f64, f64, String)>,
}
impl eframe::App for PlotApp {
fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) {
egui::CentralPanel::default().show(ctx, |ui| {
let gutter = self.plot_rect.and_then(|plot_rect| {
let cursor = ctx.input(|i| i.pointer.latest_pos())?;
if cursor.x < plot_rect.left()
&& (plot_rect.top()..=plot_rect.bottom()).contains(&cursor.y)
{
Some(Gutter::Y)
} else if cursor.y > plot_rect.bottom()
&& (plot_rect.left()..=plot_rect.right()).contains(&cursor.x)
{
Some(Gutter::X)
} else {
None
}
});
let scroll_delta = if gutter.is_some() {
ctx.input(|i| i.smooth_scroll_delta)
} else {
egui::Vec2::ZERO
};
let zoom = gutter.and_then(|g| {
if scroll_delta.y == 0.0 {
return None;
}
let f = (scroll_delta.y / 200.0).exp();
Some(match g {
Gutter::Y => egui::Vec2::new(1.0, f),
Gutter::X => egui::Vec2::new(f, 1.0),
})
});
let mut plot = Plot::new("plot_observer");
if self.legend {
plot = plot.legend(Legend::default());
}
if self.log_y {
plot = plot.y_axis_label("log₁₀");
}
if gutter.is_some() {
plot = plot.allow_scroll(false);
}
let log_y = self.log_y;
let label_size = self.label_size;
let response = plot.show(ui, |plot_ui| {
if let Some(factor) = zoom {
plot_ui.zoom_bounds_around_hovered(factor);
}
for (name, points) in &self.traces {
let plot_points: PlotPoints = if log_y {
points
.iter()
.filter(|p| p[1] > 0.0)
.map(|p| [p[0], p[1].log10()])
.collect()
} else {
points.iter().copied().collect()
};
plot_ui.line(Line::new(plot_points).name(name));
}
for (x, y, text) in &self.labels {
plot_ui.text(
Text::new(
PlotPoint::new(*x, *y),
egui::RichText::new(text).size(label_size),
)
.anchor(egui::Align2::LEFT_BOTTOM),
);
}
*plot_ui.transform().frame()
});
self.plot_rect = Some(response.inner);
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-trace observer named "a" and "b", shared by the recording tests.
    fn make_observer() -> PlotObserver<2> {
        PlotObserver::new(["a", "b"])
    }

    /// Borrows the recorded points of trace `trace`.
    fn points(obs: &PlotObserver<2>, trace: usize) -> &[[f64; 2]] {
        &obs.data[trace]
    }

    #[test]
    fn new_stores_names_and_starts_empty() {
        let obs = make_observer();
        assert_eq!(obs.names, ["a".to_owned(), "b".to_owned()]);
        assert!(obs.data.iter().all(Vec::is_empty));
        assert!(obs.labels.is_empty());
    }

    #[test]
    fn records_point_when_both_traces_are_some() {
        let mut obs = make_observer();
        obs.record(1.0, [Some(2.0), Some(3.0)]);
        assert_eq!(points(&obs, 0), [[1.0, 2.0]]);
        assert_eq!(points(&obs, 1), [[1.0, 3.0]]);
    }

    #[test]
    fn skips_only_affected_trace_when_y_is_none() {
        let mut obs = make_observer();
        obs.record(1.0, [None, Some(3.0)]);
        assert!(points(&obs, 0).is_empty());
        assert_eq!(points(&obs, 1), [[1.0, 3.0]]);
    }

    #[test]
    fn accumulates_points_across_multiple_calls() {
        let mut obs = make_observer();
        obs.record(1.0, [Some(10.0), Some(20.0)]);
        obs.record(2.0, [Some(11.0), Some(21.0)]);
        obs.record(3.0, [Some(12.0), Some(22.0)]);
        assert_eq!(points(&obs, 0), [[1.0, 10.0], [2.0, 11.0], [3.0, 12.0]]);
        assert_eq!(points(&obs, 1), [[1.0, 20.0], [2.0, 21.0], [3.0, 22.0]]);
    }

    #[test]
    fn independent_x_values_per_trace() {
        let mut obs = make_observer();
        obs.record(1.0, [Some(10.0), None]);
        obs.record(2.0, [None, Some(20.0)]);
        assert_eq!(points(&obs, 0), [[1.0, 10.0]]);
        assert_eq!(points(&obs, 1), [[2.0, 20.0]]);
    }

    #[test]
    fn accumulates_labels_across_multiple_calls() {
        let mut obs = make_observer();
        obs.label(1.0, 2.0, "a");
        obs.label(3.0, 4.0, "b");
        assert_eq!(
            obs.labels,
            vec![(1.0, 2.0, "a".to_owned()), (3.0, 4.0, "b".to_owned())]
        );
    }

    #[test]
    fn label_size_defaults_and_is_chainable() {
        let mut obs = make_observer();
        assert_eq!(obs.label_size, 14.0);
        obs.label_size(10.0).label_size(20.0);
        assert_eq!(obs.label_size, 20.0);
    }

    #[test]
    fn show_config_default_is_plain() {
        let config = ShowConfig::default();
        assert_eq!(config.title, None);
        assert!(!config.legend);
        assert!(!config.log_y);
    }

    #[test]
    fn show_config_builder_sets_every_option() {
        let config = ShowConfig::new().title("t").legend().log_y();
        assert_eq!(config.title.as_deref(), Some("t"));
        assert!(config.legend);
        assert!(config.log_y);
    }
}