use std::path::{Path, PathBuf};
use anyhow::{Result, anyhow};
use clap::Args;
use plotters::prelude::*;
use biodream::{Channel, Marker};
use super::info::read_acq;
/// Colour cycle for channel traces.
///
/// `ch_color` indexes this table modulo its length, so channels beyond the
/// tenth reuse colours from the start.
// NOTE(review): the values appear to match Matplotlib's "tab10" palette — confirm
// before relying on that correspondence.
const PALETTE: &[RGBColor] = &[
RGBColor(31, 119, 180),
RGBColor(255, 127, 14),
RGBColor(44, 160, 44),
RGBColor(214, 39, 40),
RGBColor(148, 103, 189),
RGBColor(140, 86, 75),
RGBColor(227, 119, 194),
RGBColor(127, 127, 127),
RGBColor(188, 189, 34),
RGBColor(23, 190, 207),
];
/// Returns the trace colour for channel index `idx`.
///
/// The index wraps around `PALETTE`, so any `usize` maps to a palette entry.
fn ch_color(idx: usize) -> RGBColor {
    let slot = idx % PALETTE.len();
    match PALETTE.get(slot) {
        Some(&color) => color,
        // Not reachable while `slot` is reduced modulo the palette length;
        // kept as a neutral grey so the lookup stays panic-free.
        None => RGBColor(127, 127, 127),
    }
}
// Output image formats selectable with `--format`.
//
// NOTE: plain `//` comments are used on purpose — clap's derive macros turn
// `///` doc comments on variants into user-visible help text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum ImageFormat {
// Raster output; `run` draws it through `BitMapBackend`.
Png,
// Vector output; `run` draws it through `SVGBackend`.
Svg,
}
// Command-line arguments consumed by `run`.
//
// NOTE: plain `//` comments are used on purpose — clap's derive macros turn
// `///` doc comments on fields into user-visible help text.
#[derive(Debug, Args)]
pub struct PlotArgs {
// Recording to load; handed to `read_acq`.
#[arg(value_name = "FILE")]
pub path: PathBuf,
// Explicit destination file. When absent, `resolve_output` derives a name
// from the input file's stem.
#[arg(short, long, value_name = "PATH")]
pub output: Option<PathBuf>,
// Forces the image format; when absent, `infer_format` decides from the
// output path's extension.
#[arg(long, value_enum)]
pub format: Option<ImageFormat>,
// Image width handed to the drawing backend.
#[arg(long, default_value_t = 1200)]
pub width: u32,
// Height of one channel panel; `run` multiplies it by the number of
// selected channels to get the total image height.
#[arg(long, default_value_t = 200, value_name = "HEIGHT")]
pub height_per_channel: u32,
// Comma-separated channel selectors. An entry that parses as `usize` is
// used as an index into the file's channel list; any other entry must
// equal a channel name exactly. When absent, every channel is plotted.
#[arg(long, value_delimiter = ',', value_name = "CH")]
pub channels: Option<Vec<String>>,
// Start of the plotted time window in seconds; `render` treats a missing
// value as 0 and clamps negative values to 0.
#[arg(long, value_name = "SECS")]
pub start: Option<f64>,
// End of the plotted time window in seconds; `render` caps it at each
// channel's own duration.
#[arg(long, value_name = "SECS")]
pub end: Option<f64>,
}
/// Runs the plot subcommand: loads the recording, resolves the requested
/// channels, and renders them as stacked panels into a PNG or SVG file.
///
/// The path of the written file is reported on stderr.
///
/// # Errors
///
/// Returns an error when the recording cannot be read, a channel selector
/// matches nothing, there are no channels to plot, the image height overflows
/// `u32`, or the drawing backend fails while rendering or writing.
pub fn run(args: &PlotArgs) -> Result<()> {
    let loaded = read_acq(&args.path)?;
    let df = loaded.value;

    // Resolve selectors in the order given; the first bad selector aborts.
    let selected: Vec<(usize, &Channel)> = if let Some(specs) = &args.channels {
        specs
            .iter()
            .map(|spec| match spec.parse::<usize>() {
                // Numeric selectors index into the channel list.
                Ok(idx) => df.channels.get(idx).map(|ch| (idx, ch)).ok_or_else(|| {
                    anyhow!(
                        "channel index {idx} is out of range (file has {} channels)",
                        df.channels.len()
                    )
                }),
                // Everything else must equal a channel name exactly.
                Err(_) => df
                    .channels
                    .iter()
                    .enumerate()
                    .find(|(_, ch)| ch.name == *spec)
                    .ok_or_else(|| anyhow!("channel '{spec}' not found")),
            })
            .collect::<Result<Vec<_>>>()?
    } else {
        df.channels.iter().enumerate().collect()
    };
    if selected.is_empty() {
        return Err(anyhow!("no channels to plot"));
    }

    let out_path = resolve_output(&args.path, args.output.as_deref(), args.format);
    let fmt = infer_format(&out_path, args.format);

    // One fixed-height panel per channel, stacked vertically.
    let n_panels =
        u32::try_from(selected.len()).map_err(|_| anyhow!("too many channels to plot"))?;
    let Some(total_height) = n_panels.checked_mul(args.height_per_channel) else {
        return Err(anyhow!("image dimensions overflow u32"));
    };
    let dims = (args.width, total_height);
    let base_rate = df.metadata.samples_per_second;
    let (start, end) = (args.start, args.end);

    // The two backends are distinct types, so each arm builds its own drawing area.
    match fmt {
        ImageFormat::Png => {
            let root = BitMapBackend::new(&out_path, dims).into_drawing_area();
            render(&root, &selected, &df.markers, base_rate, start, end)
                .map_err(|e| anyhow!("PNG render error: {e:?}"))?;
            root.present()
                .map_err(|e| anyhow!("PNG write error: {e:?}"))?;
        }
        ImageFormat::Svg => {
            let root = SVGBackend::new(&out_path, dims).into_drawing_area();
            render(&root, &selected, &df.markers, base_rate, start, end)
                .map_err(|e| anyhow!("SVG render error: {e:?}"))?;
            root.present()
                .map_err(|e| anyhow!("SVG write error: {e:?}"))?;
        }
    }

    eprintln!("wrote {}", out_path.display());
    Ok(())
}
/// Picks the file the plot is written to.
///
/// An explicit `output` path is returned unchanged. Otherwise the name is the
/// input file's stem plus `.svg` when `format` is `Svg`, or `.png` in every
/// other case. The result is relative, with no directory component. When
/// `input` has no file stem, the stem `output` is used.
fn resolve_output(input: &Path, output: Option<&Path>, format: Option<ImageFormat>) -> PathBuf {
    if let Some(explicit) = output {
        return explicit.to_path_buf();
    }
    let ext = match format {
        Some(ImageFormat::Svg) => "svg",
        Some(ImageFormat::Png) | None => "png",
    };
    // Assemble the name as an `OsString` so a stem that is not valid UTF-8 is
    // kept as-is instead of collapsing every such input onto "output.<ext>".
    let mut name = input.file_stem().map_or_else(
        || std::ffi::OsString::from("output"),
        std::ffi::OsStr::to_os_string,
    );
    name.push(".");
    name.push(ext);
    PathBuf::from(name)
}
/// Decides which image format to write for `path`.
///
/// A `forced` format always wins. Otherwise the extension decides: `svg` in
/// any letter case selects SVG, and everything else, including a missing
/// extension, selects PNG.
fn infer_format(path: &Path, forced: Option<ImageFormat>) -> ImageFormat {
    if let Some(format) = forced {
        return format;
    }
    // Case-insensitive so that e.g. `plot.SVG` is not routed to the bitmap backend.
    let is_svg = path
        .extension()
        .is_some_and(|ext| ext.eq_ignore_ascii_case("svg"));
    if is_svg {
        ImageFormat::Svg
    } else {
        ImageFormat::Png
    }
}
/// Converts a time in seconds to a sample index at `rate` samples per second,
/// never returning more than `max`.
///
/// The product is truncated toward zero. Rust's float-to-integer `as` cast
/// saturates, so a negative or NaN product yields 0 and an oversized one is
/// capped by `max`.
#[expect(
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    reason = "t_sec * rate is non-negative and bounded by the recording duration"
)]
fn sec_to_idx(t_sec: f64, rate: f64, max: usize) -> usize {
    let sample_pos = t_sec * rate;
    let idx = sample_pos as usize;
    if idx > max { max } else { idx }
}
/// Converts a sample index to a time in seconds at `rate` samples per second.
///
/// This is a plain floating-point division, so a zero `rate` yields an
/// infinite or NaN result rather than an error.
#[expect(
    clippy::cast_precision_loss,
    reason = "sample indices are bounded by point_count; f64 precision is sufficient for time-axis display"
)]
fn idx_to_sec(idx: usize, rate: f64) -> f64 {
    let sample_pos = idx as f64;
    sample_pos / rate
}
/// Point budget per trace: `render` strides through longer slices so that at
/// most this many samples are drawn for one channel.
const MAX_PLOT_POINTS: usize = 8_000;

/// Draws one stacked panel per entry of `channels` onto `root`.
///
/// Each panel shows the channel's scaled samples against time in seconds,
/// captioned `name [units]` and coloured by `ch_color` from the channel's
/// index. Only the bottom panel carries x-axis labels. Every marker that falls
/// inside a panel's time window is drawn as a red vertical line spanning the
/// panel's y range.
///
/// * `channels` — `(channel index, channel)` pairs, one panel each, top to bottom.
/// * `markers` — markers whose `global_sample_index` is converted to seconds
///   with `base_rate`.
/// * `start_sec` / `end_sec` — optional time window. The start defaults to 0
///   and is clamped to be non-negative. The end is capped at each channel's own
///   duration, and a window that would be empty is widened to 1 ms.
///
/// # Errors
///
/// Propagates any error reported by the drawing backend.
fn render<DB: DrawingBackend>(
    root: &DrawingArea<DB, plotters::coord::Shift>,
    channels: &[(usize, &Channel)],
    markers: &[Marker],
    base_rate: f64,
    start_sec: Option<f64>,
    end_sec: Option<f64>,
) -> Result<(), DrawingAreaErrorKind<DB::ErrorType>> {
    root.fill(&WHITE)?;
    let panels = root.split_evenly((channels.len(), 1));
    let last_idx = channels.len().saturating_sub(1);

    // Neither the window start nor the marker times depend on the channel,
    // so compute them once instead of once per panel.
    let t_start = start_sec.unwrap_or(0.0_f64).max(0.0_f64);
    let marker_times: Vec<f64> = markers
        .iter()
        .map(|m| idx_to_sec(m.global_sample_index, base_rate))
        .collect();

    for (panel_idx, (panel, &(palette_idx, ch))) in panels.iter().zip(channels.iter()).enumerate() {
        let rate = ch.samples_per_second;
        let n_total = ch.point_count;

        // Cap the window end at this channel's duration and keep the window non-empty.
        let t_end_max = idx_to_sec(n_total, rate);
        let t_end = end_sec.unwrap_or(t_end_max).min(t_end_max);
        let t_end = if t_end <= t_start {
            t_start + 0.001
        } else {
            t_end
        };

        // Half-open sample range covering the window; the `+ 1` includes the
        // sample at `t_end` itself when one exists.
        let i_start = sec_to_idx(t_start, rate, n_total);
        let i_end = (sec_to_idx(t_end, rate, n_total) + 1).min(n_total);
        let samples = ch.scaled_samples();
        // An out-of-bounds range yields an empty trace rather than a panic.
        let slice = samples.get(i_start..i_end).unwrap_or(&[]);
        let (y_min, y_max) = y_range(slice);

        let is_last = panel_idx == last_idx;
        let caption = format!("{} [{}]", ch.name, ch.units);
        let mut chart = ChartBuilder::on(panel)
            .caption(caption, ("sans-serif", 14).into_font())
            .margin(4_u32)
            .x_label_area_size(if is_last { 30_u32 } else { 0_u32 })
            .y_label_area_size(65_u32)
            .build_cartesian_2d(t_start..t_end, y_min..y_max)?;
        chart
            .configure_mesh()
            .x_labels(if is_last { 5 } else { 0 })
            .y_labels(4)
            .x_desc(if is_last { "Time (s)" } else { "" })
            .draw()?;

        // Ceiling division keeps the drawn point count at or below
        // MAX_PLOT_POINTS; floor division would let up to twice the budget
        // through (e.g. 15_999 samples would still use stride 1).
        let stride = slice.len().div_ceil(MAX_PLOT_POINTS).max(1);
        let color = ch_color(palette_idx);
        chart.draw_series(LineSeries::new(
            slice
                .iter()
                .enumerate()
                .step_by(stride)
                .map(|(i, &v)| (idx_to_sec(i_start + i, rate), v)),
            color,
        ))?;

        for &t_m in &marker_times {
            if t_m < t_start || t_m > t_end {
                continue;
            }
            chart.draw_series(std::iter::once(PathElement::new(
                vec![(t_m, y_min), (t_m, y_max)],
                RED,
            )))?;
        }
    }
    Ok(())
}
/// Computes the y-axis bounds for a trace.
///
/// Non-finite samples are ignored. The result is:
/// * `(-1.0, 1.0)` when there is no finite sample at all;
/// * the value ± 1 when the finite samples are flat (spread below `f64::EPSILON`);
/// * otherwise the min/max widened by 5 % of the spread on each side.
fn y_range(samples: &[f64]) -> (f64, f64) {
    const FALLBACK: (f64, f64) = (-1.0, 1.0);

    // Track the running extremes of the finite samples; `None` means none seen yet.
    let extremes = samples
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold(None, |acc: Option<(f64, f64)>, v| match acc {
            None => Some((v, v)),
            Some((lo, hi)) => Some((
                if v < lo { v } else { lo },
                if v > hi { v } else { hi },
            )),
        });
    let Some((lo, hi)) = extremes else {
        return FALLBACK;
    };

    let spread = hi - lo;
    if spread.abs() < f64::EPSILON {
        // A flat trace still needs a non-degenerate axis.
        (lo - 1.0, hi + 1.0)
    } else {
        let pad = spread * 0.05;
        (lo - pad, hi + pad)
    }
}