use super::{GraphDef, ImageData};
use serde::{Deserialize, Serialize};
/// A single logged entry: a tagged value recorded at a given step.
///
/// Instances are normally built through the constructors on `Summary`
/// (`scalar`, `histogram`, `image`, `text`, `graph`, `pr_curve`, `audio`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Summary {
/// Tag string identifying this entry; `Summary::graph` always uses `"graph"`.
pub tag: String,
/// Step index supplied by the caller; `Summary::graph` always uses `0`.
pub step: usize,
/// The payload carried by this entry.
pub value: SummaryValue,
}
/// The payload of a [`Summary`], one variant per supported kind of data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SummaryValue {
/// A single floating-point value.
Scalar(f32),
/// Aggregate statistics and bucket counts computed from a slice of values.
Histogram(HistogramData),
/// Image dimensions together with the image bytes.
Image(ImageSummary),
/// Free-form text.
Text(String),
/// A graph definition (type declared in the parent module).
Graph(GraphDef),
/// Precision/recall curve data derived from labels and predictions.
PrCurve(PrCurveData),
/// Audio samples with their sample rate and content type.
Audio(AudioSummary),
}
/// Aggregate statistics and bucket counts for a set of `f32` values.
///
/// Produced by `create_histogram`; for empty input every numeric field is
/// zero and `buckets` is empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramData {
/// Number of input values.
pub count: usize,
/// Sum of the input values, accumulated in `f64`.
pub sum: f64,
/// Sum of the squared input values, accumulated in `f64`.
pub sum_squares: f64,
/// Smallest input value.
pub min: f32,
/// Largest input value.
pub max: f32,
/// Buckets in order of increasing upper edge.
pub buckets: Vec<HistogramBucket>,
}
/// One bucket of a [`HistogramData`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramBucket {
/// Upper edge of the bucket; values equal to the edge are counted in it.
pub edge: f32,
/// Number of values that fell into this bucket.
pub count: usize,
}
/// Image payload of a [`Summary`]: dimensions plus the image bytes.
///
/// `Summary::image` copies the dimensions from the source `ImageData`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageSummary {
/// Image height (presumably in pixels; confirm against `ImageData`).
pub height: u32,
/// Image width (presumably in pixels; confirm against `ImageData`).
pub width: u32,
/// Number of channels as reported by the source image.
pub channels: u32,
/// Bytes returned by `encode_image_as_png`.
/// NOTE(review): that helper currently returns the source bytes unchanged,
/// so this is not guaranteed to be PNG data.
pub encoded_image: Vec<u8>,
}
/// Precision/recall curve data produced by `create_pr_curve`.
///
/// All vectors are parallel: entry `k` describes the state after the `k + 1`
/// highest-scoring samples have been treated as predicted positive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrCurveData {
/// Cumulative true-positive counts.
pub tp: Vec<usize>,
/// Cumulative false-positive counts.
pub fp: Vec<usize>,
/// Remaining true-negative counts.
pub tn: Vec<usize>,
/// Remaining false-negative counts (`fn` is a keyword, hence the underscore).
pub fn_: Vec<usize>,
/// Precision at each point: `tp / (tp + fp)`.
pub precision: Vec<f32>,
/// Recall at each point: `tp / total positives`, or `0.0` when there are none.
pub recall: Vec<f32>,
/// Prediction score of the sample that produced each point, in descending order.
pub thresholds: Vec<f32>,
}
/// Audio payload of a [`Summary`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioSummary {
/// Sample rate supplied by the caller (presumably in Hz; confirm with callers).
pub sample_rate: u32,
/// Samples copied verbatim from the caller's slice by `Summary::audio`.
pub audio_data: Vec<f32>,
/// Content type label; `Summary::audio` always sets `"audio/wav"`.
/// NOTE(review): the samples are stored as raw `f32` values, not as an
/// encoded WAV stream; confirm what consumers expect.
pub content_type: String,
}
impl Summary {
pub fn scalar(tag: &str, value: f32, step: usize) -> Self {
Self {
tag: tag.to_string(),
step,
value: SummaryValue::Scalar(value),
}
}
pub fn histogram(tag: &str, values: &[f32], step: usize) -> Self {
let histogram_data = create_histogram(values);
Self {
tag: tag.to_string(),
step,
value: SummaryValue::Histogram(histogram_data),
}
}
pub fn image(tag: &str, image: &ImageData, step: usize) -> Self {
let encoded_image = encode_image_as_png(image);
let image_summary = ImageSummary {
height: image.height,
width: image.width,
channels: image.channels,
encoded_image,
};
Self {
tag: tag.to_string(),
step,
value: SummaryValue::Image(image_summary),
}
}
pub fn text(tag: &str, text: &str, step: usize) -> Self {
Self {
tag: tag.to_string(),
step,
value: SummaryValue::Text(text.to_string()),
}
}
pub fn graph(graph: &GraphDef) -> Self {
Self {
tag: "graph".to_string(),
step: 0,
value: SummaryValue::Graph(graph.clone()),
}
}
pub fn pr_curve(tag: &str, labels: &[bool], predictions: &[f32], step: usize) -> Self {
let pr_data = create_pr_curve(labels, predictions);
Self {
tag: tag.to_string(),
step,
value: SummaryValue::PrCurve(pr_data),
}
}
pub fn audio(tag: &str, audio_data: &[f32], sample_rate: u32, step: usize) -> Self {
let audio_summary = AudioSummary {
sample_rate,
audio_data: audio_data.to_vec(),
content_type: "audio/wav".to_string(),
};
Self {
tag: tag.to_string(),
step,
value: SummaryValue::Audio(audio_summary),
}
}
}
/// Computes aggregate statistics and a 30-bucket equal-width histogram of `values`.
///
/// The range `[min, max]` is split into 30 buckets. Each bucket is identified
/// by its inclusive upper edge, and a value is counted in the first bucket
/// whose edge is greater than or equal to it (so `min` lands in the first
/// bucket and `max` in the last non-degenerate one).
///
/// * Empty input yields all-zero statistics and no buckets.
/// * The last edge is pinned to `max` and interior edges are clamped to it,
///   so floating-point rounding (or an overflowing range) can never leave a
///   value outside every bucket.
/// * `NaN` inputs are included in `count` (and propagate into `sum` /
///   `sum_squares`) but are not assigned to any bucket; they are ignored
///   when determining `min` and `max`.
fn create_histogram(values: &[f32]) -> HistogramData {
    // Number of equal-width buckets the value range is divided into.
    const NUM_BUCKETS: usize = 30;

    if values.is_empty() {
        return HistogramData {
            count: 0,
            sum: 0.0,
            sum_squares: 0.0,
            min: 0.0,
            max: 0.0,
            buckets: Vec::new(),
        };
    }

    let sum: f64 = values.iter().map(|&x| f64::from(x)).sum();
    let sum_squares: f64 = values.iter().map(|&x| f64::from(x).powi(2)).sum();
    let min = values.iter().copied().fold(f32::INFINITY, f32::min);
    let max = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let bucket_width = (max - min) / NUM_BUCKETS as f32;

    // Inclusive upper edges in non-decreasing order. `min + 30 * width` may
    // round to just below `max`, so the final edge is set to `max` exactly;
    // `f32::min` also replaces NaN/overflowed interior edges with `max`.
    let edges: Vec<f32> = (1..=NUM_BUCKETS)
        .map(|i| {
            if i == NUM_BUCKETS {
                max
            } else {
                (min + i as f32 * bucket_width).min(max)
            }
        })
        .collect();

    // Single pass: binary-search the first edge that is >= the value.
    let mut counts = vec![0usize; NUM_BUCKETS];
    for &x in values {
        if x.is_nan() {
            continue;
        }
        let idx = edges.partition_point(|&edge| edge < x);
        if let Some(slot) = counts.get_mut(idx) {
            *slot += 1;
        }
    }

    let buckets = edges
        .into_iter()
        .zip(counts)
        .map(|(edge, count)| HistogramBucket { edge, count })
        .collect();

    HistogramData {
        count: values.len(),
        sum,
        sum_squares,
        min,
        max,
        buckets,
    }
}
/// Builds precision/recall curve data from boolean labels and prediction scores.
///
/// Samples are paired positionally and visited in order of descending
/// prediction score. After each sample one point is emitted, treating that
/// sample and every higher-scoring one as predicted positive, so every output
/// vector has one entry per paired sample.
///
/// * If `labels` and `predictions` differ in length, only the common prefix is
///   used, and the positive/negative totals are computed over that prefix so
///   that `tp + fn_` and `fp + tn` stay consistent.
/// * Scores are ordered with `f32::total_cmp`, so `NaN` predictions no longer
///   cause a panic; a positive `NaN` sorts ahead of every other score.
/// * Recall is `0.0` at every point when there are no positive labels.
fn create_pr_curve(labels: &[bool], predictions: &[f32]) -> PrCurveData {
    let mut samples: Vec<(bool, f32)> = labels
        .iter()
        .copied()
        .zip(predictions.iter().copied())
        .collect();
    // Stable sort, highest score first; ties keep their input order.
    samples.sort_by(|a, b| b.1.total_cmp(&a.1));

    let n = samples.len();
    let total_positive = samples.iter().filter(|&&(label, _)| label).count();
    let total_negative = n - total_positive;

    let mut tp = Vec::with_capacity(n);
    let mut fp = Vec::with_capacity(n);
    let mut tn = Vec::with_capacity(n);
    let mut fn_ = Vec::with_capacity(n);
    let mut precision = Vec::with_capacity(n);
    let mut recall = Vec::with_capacity(n);
    let mut thresholds = Vec::with_capacity(n);

    let mut true_positives = 0usize;
    let mut false_positives = 0usize;
    for &(label, pred) in &samples {
        if label {
            true_positives += 1;
        } else {
            false_positives += 1;
        }

        // At least one sample has been counted, so the denominator is >= 1.
        let predicted_positive = true_positives + false_positives;
        let prec = true_positives as f32 / predicted_positive as f32;
        let rec = if total_positive > 0 {
            true_positives as f32 / total_positive as f32
        } else {
            0.0
        };

        tp.push(true_positives);
        fp.push(false_positives);
        tn.push(total_negative - false_positives);
        fn_.push(total_positive - true_positives);
        precision.push(prec);
        recall.push(rec);
        thresholds.push(pred);
    }

    PrCurveData {
        tp,
        fp,
        tn,
        fn_,
        precision,
        recall,
        thresholds,
    }
}
/// Returns the bytes stored for `image`.
///
/// NOTE(review): despite the name, no PNG encoding is performed here; the
/// function returns a clone of `image.data` unchanged. Whether that buffer
/// already holds PNG data depends on how `ImageData` is populated; confirm
/// against its definition in the parent module.
fn encode_image_as_png(image: &ImageData) -> Vec<u8> {
image.data.clone()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A scalar summary keeps the tag, step and value it was built with.
    #[test]
    fn test_scalar_summary() {
        let summary = Summary::scalar("test_loss", 0.5, 100);
        assert_eq!(summary.tag, "test_loss");
        assert_eq!(summary.step, 100);
        match summary.value {
            SummaryValue::Scalar(value) => assert_eq!(value, 0.5),
            _ => panic!("Expected scalar value"),
        }
    }

    /// Histogram statistics reflect the input and at least one bucket exists.
    #[test]
    fn test_histogram_creation() {
        let values = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let histogram = create_histogram(&values);
        assert_eq!(histogram.count, 5);
        assert_eq!(histogram.sum, 15.0);
        assert_eq!(histogram.min, 1.0);
        assert_eq!(histogram.max, 5.0);
        assert!(!histogram.buckets.is_empty());
    }

    /// The PR curve emits one point per prediction.
    #[test]
    fn test_pr_curve_creation() {
        let labels = [true, false, true, false, true];
        let predictions = [0.9_f32, 0.1, 0.8, 0.3, 0.7];
        let expected_len = predictions.len();
        let pr_curve = create_pr_curve(&labels, &predictions);
        assert_eq!(pr_curve.precision.len(), expected_len);
        assert_eq!(pr_curve.recall.len(), expected_len);
        assert_eq!(pr_curve.thresholds.len(), expected_len);
    }
}