use std::collections::HashMap;
use crate::error::{NeuralError, Result};
/// Layout options for [`plot_metrics`].
///
/// Construct with [`Default::default`] (`width = 80`, `height = 20`) or with a
/// struct literal / struct-update syntax.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PlotOptions {
    /// Maximum number of chart columns per series; longer series are
    /// downsampled (by bucket means) to at most this many columns.
    pub width: usize,
    /// Requested chart height in text rows. Note that `plot_metrics` clamps
    /// this value to the range `4..=10` before drawing.
    pub height: usize,
}

impl Default for PlotOptions {
    /// Returns options with `width = 80` and `height = 20`.
    fn default() -> Self {
        Self {
            width: 80,
            height: 20,
        }
    }
}
/// Glyphs for drawing a bar in eighth-of-a-row increments.
///
/// `BLOCK_CHARS[k]` fills `k / 8` of a character cell from the bottom:
/// index 0 is a blank cell and index 8 is a full block.
const BLOCK_CHARS: [char; 9] = [' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'];

/// Renders a numeric series as a multi-row bar chart built from Unicode block glyphs.
///
/// The series is downsampled to `width` columns, clamped to at least one and at
/// most `data.len()`. Each column shows the mean of a contiguous, non-empty
/// bucket of samples. Column means are scaled linearly: the smallest maps to
/// zero height and the largest to the full `height` rows. Each row is split
/// into eight sub-steps. If the spread between the largest and smallest column
/// mean is below `1e-12`, every column is drawn at half height.
///
/// Returns `height` lines joined by `'\n'` with no trailing newline. Every line
/// has exactly one character per column. Returns an empty string when `data`
/// is empty or `height` is zero.
pub fn ascii_plot(data: &[f64], width: usize, height: usize) -> String {
    if data.is_empty() || height == 0 {
        return String::new();
    }

    let len = data.len();
    let columns = width.min(len).max(1);

    // Spread bucket boundaries evenly over the data. `max(lo + 1)` keeps every
    // bucket non-empty, so the mean below is always defined.
    let means: Vec<f64> = (0..columns)
        .map(|c| {
            let lo = c * len / columns;
            let hi = ((c + 1) * len / columns).max(lo + 1).min(len);
            let bucket = &data[lo..hi];
            bucket.iter().sum::<f64>() / bucket.len() as f64
        })
        .collect();

    // `f64::min`/`f64::max` skip NaN operands, so NaN means do not poison the bounds.
    let (lowest, highest) = means
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &m| {
            (lo.min(m), hi.max(m))
        });
    let span = highest - lowest;

    // Bar heights are measured in eighth-of-a-row steps.
    let max_steps = height * 8;
    let levels: Vec<usize> = if span < 1e-12 {
        vec![max_steps / 2; columns]
    } else {
        means
            .iter()
            .map(|&m| ((m - lowest) / span * max_steps as f64).round() as usize)
            .collect()
    };

    // Draw from the top band down. Within a band of 8 steps, a bar ending at or
    // below the band's floor is blank, one ending inside it gets a partial
    // glyph, and one reaching the band's top gets a full block.
    (0..height)
        .rev()
        .map(|band| {
            let floor = band * 8;
            levels
                .iter()
                .map(|&lvl| BLOCK_CHARS[lvl.saturating_sub(floor).min(8)])
                .collect::<String>()
        })
        .collect::<Vec<String>>()
        .join("\n")
}
/// Renders every non-empty series in `history` as a labelled block-character chart.
///
/// When `title` is given, it is drawn first inside a box. The box width is
/// `options.width` capped at 80 columns, and is widened when the title needs
/// more room, so the borders always line up.
///
/// Series are emitted in ascending key order; empty series are skipped. Each
/// series is printed as a ` {key}:` label line, followed by the chart from
/// [`ascii_plot`] with every line framed by `│`, then a blank line. Charts use
/// `options.width` columns and `options.height` clamped to `4..=10` rows.
///
/// `options` falls back to `PlotOptions::default()` when `None`.
///
/// # Errors
///
/// Currently always returns `Ok`.
pub fn plot_metrics(
    history: &HashMap<String, Vec<f64>>,
    title: Option<&str>,
    options: Option<PlotOptions>,
) -> Result<String> {
    let opts = options.unwrap_or_default();
    let mut output = String::new();

    if let Some(t) = title {
        // The title row is "│ " + padded title + "│", so the padded title gets
        // one column less than the border. Sizing the box to at least
        // `chars + 1` keeps long titles aligned and guarantees `box_width >= 1`.
        // This avoids the `0 - 1` usize underflow when `width == 0`.
        // (`format!` padding counts chars, matching `chars().count()`.)
        let box_width = opts.width.min(80).max(t.chars().count() + 1);
        let border = "─".repeat(box_width);
        output.push_str(&format!("┌{border}┐\n"));
        output.push_str(&format!("│ {t:<width$}│\n", width = box_width - 1));
        output.push_str(&format!("└{border}┘\n"));
    }

    // HashMap iteration order is unspecified; sort so the output is stable.
    let mut entries: Vec<(&String, &Vec<f64>)> = history.iter().collect();
    entries.sort_unstable_by_key(|(key, _)| *key);

    let plot_height = opts.height.clamp(4, 10);
    for (key, values) in entries.into_iter().filter(|(_, v)| !v.is_empty()) {
        output.push_str(&format!(" {key}:\n"));
        for line in ascii_plot(values, opts.width, plot_height).lines() {
            output.push_str(&format!(" │{line}│\n"));
        }
        output.push('\n');
    }

    Ok(output)
}
/// Converts a slice of generic floating-point values into `f64` for plotting.
///
/// Elements whose `to_f64` conversion yields `None` are skipped, so the output
/// may be shorter than `values`.
///
/// # Errors
///
/// Returns `NeuralError::ComputationError` when `values` is non-empty but not a
/// single element could be converted.
pub fn float_vec_to_f64<F: scirs2_core::numeric::Float>(values: &[F]) -> Result<Vec<f64>> {
    let mut out = Vec::with_capacity(values.len());
    for value in values {
        if let Some(x) = value.to_f64() {
            out.push(x);
        }
    }

    if !values.is_empty() && out.is_empty() {
        Err(NeuralError::ComputationError(
            "Could not convert any metric values to f64 for visualization".to_string(),
        ))
    } else {
        Ok(out)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// True for the non-blank glyphs that `ascii_plot` can emit.
    fn is_bar_glyph(c: char) -> bool {
        matches!(c, '▁' | '▂' | '▃' | '▄' | '▅' | '▆' | '▇' | '█')
    }

    #[test]
    fn test_ascii_plot_sinusoid() {
        let wave: Vec<f64> = (0..32).map(|step| (step as f64 * PI / 16.0).sin()).collect();
        let rendered = ascii_plot(&wave, 32, 8);

        assert!(!rendered.is_empty());
        assert_eq!(rendered.lines().count(), 8, "expected 8 rows");
        assert!(
            rendered.chars().any(is_bar_glyph),
            "plot should contain block characters"
        );
    }

    #[test]
    fn test_ascii_plot_empty_returns_empty() {
        assert!(ascii_plot(&[], 10, 5).is_empty());
    }

    #[test]
    fn test_ascii_plot_zero_height_returns_empty() {
        assert!(ascii_plot(&[1.0, 2.0, 3.0], 10, 0).is_empty());
    }

    #[test]
    fn test_ascii_plot_constant_series() {
        let flat = [5.0_f64; 20];
        assert_eq!(ascii_plot(&flat, 20, 4).lines().count(), 4);
    }

    #[test]
    fn test_plot_metrics_produces_labelled_output() {
        let history: HashMap<String, Vec<f64>> = HashMap::from([
            ("train_loss".to_string(), vec![1.0, 0.8, 0.6, 0.5, 0.4]),
            ("val_loss".to_string(), vec![1.1, 0.9, 0.7, 0.6, 0.5]),
        ]);
        let options = PlotOptions {
            width: 40,
            height: 6,
        };

        let result = plot_metrics(&history, Some("Test Metrics"), Some(options));
        assert!(result.is_ok());
        let output = result.expect("should succeed");

        for label in ["train_loss", "val_loss", "Test Metrics"] {
            assert!(output.contains(label));
        }
    }
}