use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::{self, File};
use std::io::Write;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
pub mod event_writer;
pub mod summary;
use self::event_writer::EventWriter;
use self::summary::Summary;
/// Buffers summaries in memory and forwards them to an `EventWriter` in batches.
///
/// Summaries queued through the `add_*` methods are held in `pending_summaries`
/// until `flush` drains them into `event_writer`.
pub struct SummaryWriter {
/// Directory given to `new`; `add_embedding` creates its `projector` subdirectory here.
log_dir: PathBuf,
/// Sink that receives every drained summary in `flush`.
event_writer: EventWriter,
/// Step used when an `add_*` call passes `step: None`; incremented by `step()`.
global_step: usize,
/// Number of pending summaries at which the `add_*` methods flush (set to 10 in `new`).
flush_interval: usize,
/// Summaries queued since the last `flush`.
pending_summaries: Vec<Summary>,
}
impl SummaryWriter {
/// Creates a writer that logs into `log_dir`.
///
/// The directory (including missing parents) is created first. The event file is
/// then opened as `events.out.tfevents.<unix-seconds>.<hostname>` inside it. The
/// writer starts at global step 0 and flushes after every 10 pending summaries.
///
/// # Errors
///
/// Returns an error if the directory cannot be created, if the system clock
/// reports a time before the Unix epoch, or if `EventWriter::new` fails.
pub fn new(log_dir: impl AsRef<Path>) -> std::io::Result<Self> {
    let log_dir = log_dir.as_ref().to_path_buf();
    fs::create_dir_all(&log_dir)?;

    // A clock set before 1970 is an environment problem, not a bug in the caller:
    // surface it as an I/O error instead of panicking.
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?
        .as_secs();

    // Fall back to "unknown" when the host name cannot be queried.
    #[cfg(not(target_arch = "wasm32"))]
    let hostname = hostname::get()
        .unwrap_or_else(|_| std::ffi::OsString::from("unknown"))
        .to_string_lossy()
        .to_string();
    #[cfg(target_arch = "wasm32")]
    let hostname = "wasm-browser".to_string();

    let filename = format!("events.out.tfevents.{}.{}", timestamp, hostname);
    let event_path = log_dir.join(filename);
    let event_writer = EventWriter::new(event_path)?;

    Ok(Self {
        log_dir,
        event_writer,
        global_step: 0,
        flush_interval: 10,
        pending_summaries: Vec::new(),
    })
}
/// Queues a scalar summary for `tag`.
///
/// When `step` is `None` the writer's current global step is used. The queue is
/// flushed once the number of pending summaries reaches the flush interval.
pub fn add_scalar(&mut self, tag: &str, value: f32, step: Option<usize>) {
    let at_step = step.unwrap_or(self.global_step);
    self.pending_summaries
        .push(Summary::scalar(tag, value, at_step));

    // Below the batching threshold: keep buffering.
    if self.pending_summaries.len() < self.flush_interval {
        return;
    }
    self.flush();
}
/// Queues a histogram summary built from `values` for `tag`.
///
/// When `step` is `None` the writer's current global step is used. The queue is
/// flushed once the number of pending summaries reaches the flush interval.
pub fn add_histogram(&mut self, tag: &str, values: &[f32], step: Option<usize>) {
    let at_step = step.unwrap_or(self.global_step);
    self.pending_summaries
        .push(Summary::histogram(tag, values, at_step));

    // Below the batching threshold: keep buffering.
    if self.pending_summaries.len() < self.flush_interval {
        return;
    }
    self.flush();
}
/// Queues an image summary for `tag`.
///
/// When `step` is `None` the writer's current global step is used. The queue is
/// flushed once the number of pending summaries reaches the flush interval.
pub fn add_image(&mut self, tag: &str, image: &ImageData, step: Option<usize>) {
    let at_step = step.unwrap_or(self.global_step);
    self.pending_summaries
        .push(Summary::image(tag, image, at_step));

    // Below the batching threshold: keep buffering.
    if self.pending_summaries.len() < self.flush_interval {
        return;
    }
    self.flush();
}
/// Queues a text summary for `tag`.
///
/// When `step` is `None` the writer's current global step is used. The queue is
/// flushed once the number of pending summaries reaches the flush interval.
pub fn add_text(&mut self, tag: &str, text: &str, step: Option<usize>) {
    let at_step = step.unwrap_or(self.global_step);
    self.pending_summaries
        .push(Summary::text(tag, text, at_step));

    // Below the batching threshold: keep buffering.
    if self.pending_summaries.len() < self.flush_interval {
        return;
    }
    self.flush();
}
/// Queues a graph summary and flushes immediately.
///
/// Unlike the other `add_*` methods this does not wait for the flush interval:
/// everything pending, including the new graph summary, is written right away.
pub fn add_graph(&mut self, graph: &GraphDef) {
    self.pending_summaries.push(Summary::graph(graph));
    self.flush();
}
/// Writes an embedding matrix (and optional labels) for the projector.
///
/// Files are created under `<log_dir>/projector`:
///
/// * `<tag>_tensor.tsv` holds one line per row of `mat`, values separated by tabs.
/// * `<tag>_metadata.tsv` holds one label per line and is written only when
///   `metadata` is `Some`.
/// * `projector_config.pbtxt` describes the embedding. It is rewritten on every
///   call, so it only ever lists the most recent embedding.
///
/// `tag` defaults to `"default"`. It is interpolated into the file names as-is,
/// so callers should pass a value that is a valid file-name fragment.
///
/// The config references a metadata file only when one was actually written.
///
/// # Errors
///
/// Returns any I/O error raised while creating the directory or writing a file.
pub fn add_embedding(
    &mut self,
    mat: &[Vec<f32>],
    metadata: Option<Vec<String>>,
    tag: Option<&str>,
) -> std::io::Result<()> {
    let tag = tag.unwrap_or("default");
    let projector_dir = self.log_dir.join("projector");
    fs::create_dir_all(&projector_dir)?;

    let tensor_name = format!("{}_tensor.tsv", tag);
    // Buffer the writes: one syscall per row is needlessly slow for large matrices.
    let mut tensor_file =
        std::io::BufWriter::new(File::create(projector_dir.join(&tensor_name))?);
    for row in mat {
        let line = row
            .iter()
            .map(|value| value.to_string())
            .collect::<Vec<_>>()
            .join("\t");
        writeln!(tensor_file, "{}", line)?;
    }
    // Flush explicitly: dropping a BufWriter discards write errors.
    tensor_file.flush()?;

    // Only advertise a metadata file in the config if one is actually produced.
    let metadata_path = match metadata {
        Some(labels) => {
            let metadata_name = format!("{}_metadata.tsv", tag);
            let mut metadata_file =
                std::io::BufWriter::new(File::create(projector_dir.join(&metadata_name))?);
            for label in labels {
                writeln!(metadata_file, "{}", label)?;
            }
            metadata_file.flush()?;
            Some(metadata_name)
        }
        None => None,
    };

    let config = ProjectorConfig {
        embeddings: vec![EmbeddingInfo {
            tensor_name,
            metadata_path,
        }],
    };
    let config_path = projector_dir.join("projector_config.pbtxt");
    fs::write(config_path, format_projector_config(&config))?;
    Ok(())
}
/// Queues a precision-recall curve summary for `tag`.
///
/// `labels` and `predictions` are handed to `Summary::pr_curve` unchanged.
/// When `step` is `None` the writer's current global step is used. The queue is
/// flushed once the number of pending summaries reaches the flush interval.
///
/// NOTE(review): the two slices are presumably expected to have equal length;
/// nothing here checks that — confirm against `Summary::pr_curve`.
pub fn add_pr_curve(
    &mut self,
    tag: &str,
    labels: &[bool],
    predictions: &[f32],
    step: Option<usize>,
) {
    let at_step = step.unwrap_or(self.global_step);
    self.pending_summaries
        .push(Summary::pr_curve(tag, labels, predictions, at_step));

    // Below the batching threshold: keep buffering.
    if self.pending_summaries.len() < self.flush_interval {
        return;
    }
    self.flush();
}
/// Hands every pending summary to the event writer, then flushes the writer.
///
/// Summaries are written in the order they were queued, and the pending queue is
/// empty afterwards.
///
/// NOTE(review): whatever `write_summary` and the writer's `flush` return is
/// discarded here; if they report errors, those are dropped — confirm.
pub fn flush(&mut self) {
    let sink = &mut self.event_writer;
    for queued in self.pending_summaries.drain(..) {
        sink.write_summary(queued);
    }
    sink.flush();
}
/// Advances the global step by one.
///
/// The global step is the value the `add_*` methods use when called with
/// `step: None`.
pub fn step(&mut self) {
self.global_step += 1;
}
/// Flushes all pending summaries and consumes the writer.
///
/// NOTE(review): no `Drop` implementation is visible in this file, so summaries
/// still pending when a writer goes out of scope are presumably not written
/// unless `flush` or `close` was called — confirm.
pub fn close(mut self) {
self.flush();
}
}
/// Image payload accepted by `SummaryWriter::add_image`.
///
/// NOTE(review): the layout of `data` is not established in this file;
/// presumably it holds `height * width * channels` bytes — confirm against
/// `Summary::image`.
#[derive(Debug, Clone)]
pub struct ImageData {
/// Image height.
pub height: u32,
/// Image width.
pub width: u32,
/// Number of channels per pixel.
pub channels: u32,
/// Raw image bytes.
pub data: Vec<u8>,
}
/// Graph description accepted by `SummaryWriter::add_graph`.
///
/// Serializable and deserializable through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphDef {
/// Nodes of the graph.
pub nodes: Vec<NodeDef>,
/// Edges of the graph.
pub edges: Vec<EdgeDef>,
}
/// A single node of a `GraphDef`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeDef {
/// Node name.
pub name: String,
/// Operation the node represents.
pub op: String,
/// Inputs of the node (presumably names of other nodes — confirm with producers).
pub inputs: Vec<String>,
/// Free-form string attributes keyed by attribute name.
pub attrs: HashMap<String, String>,
}
/// A single edge of a `GraphDef`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeDef {
/// Source end of the edge (presumably a node name — confirm with producers).
pub source: String,
/// Target end of the edge (presumably a node name — confirm with producers).
pub target: String,
/// Optional label attached to the edge.
pub label: Option<String>,
}
/// In-memory form of the projector configuration rendered by
/// `format_projector_config` and written by `SummaryWriter::add_embedding`.
#[derive(Debug, Clone)]
struct ProjectorConfig {
/// One entry per embedding; each becomes an `embeddings { ... }` section.
embeddings: Vec<EmbeddingInfo>,
}
/// One embedding entry of a `ProjectorConfig`.
///
/// NOTE(review): `add_embedding` stores the TSV file name in `tensor_name`.
/// TensorBoard's projector config has separate `tensor_name` and `tensor_path`
/// fields — confirm that `tensor_name` is the intended field for a file name.
#[derive(Debug, Clone)]
struct EmbeddingInfo {
/// Value emitted as `tensor_name` in the rendered config.
tensor_name: String,
/// Value emitted as `metadata_path`; the line is omitted when `None`.
metadata_path: Option<String>,
}
/// Renders `config` as text, one `embeddings { ... }` section per entry.
///
/// Each section contains a `tensor_name` line and, when the entry has one, a
/// `metadata_path` line. Values are inserted between double quotes without any
/// escaping. An empty config renders as an empty string.
fn format_projector_config(config: &ProjectorConfig) -> String {
    config
        .embeddings
        .iter()
        .fold(String::new(), |mut rendered, entry| {
            rendered += "embeddings {\n";
            rendered += &format!(" tensor_name: \"{}\"\n", entry.tensor_name);
            if let Some(path) = entry.metadata_path.as_deref() {
                rendered += &format!(" metadata_path: \"{}\"\n", path);
            }
            rendered += "}\n";
            rendered
        })
}
/// Thin convenience helpers layered over [`SummaryWriter`].
pub mod python_compat {
    use super::*;

    /// Creates a writer logging into `<base_dir>/experiment_<unix-seconds>`.
    ///
    /// `base_dir` defaults to `"runs"`. The path is assembled with `Path::join`,
    /// so an empty `base_dir` yields a directory relative to the current working
    /// directory rather than one under the filesystem root.
    ///
    /// # Errors
    ///
    /// Returns an error if the system clock reports a time before the Unix epoch
    /// or if `SummaryWriter::new` fails.
    pub fn create_writer(base_dir: Option<&str>) -> std::io::Result<SummaryWriter> {
        let base_dir = base_dir.unwrap_or("runs");
        // A pre-1970 clock is reported as an I/O error instead of panicking.
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))?
            .as_secs();
        let log_dir = Path::new(base_dir).join(format!("experiment_{}", timestamp));
        SummaryWriter::new(log_dir)
    }

    /// Logs `value` under `tag` at the writer's current global step, then
    /// advances the global step by one.
    pub fn log_scalar(writer: &mut SummaryWriter, tag: &str, value: f32) {
        writer.add_scalar(tag, value, None);
        writer.step();
    }
}
/// Shorthand for the `add_scalar`, `add_histogram` and `add_text` methods of a writer.
///
/// Forms (a trailing comma is accepted in each):
///
/// ```ignore
/// tb_log!(writer, scalar: "loss", 0.5);
/// tb_log!(writer, scalar: "loss", 0.5, step: 10);
/// tb_log!(writer, histogram: "weights", &values);
/// tb_log!(writer, histogram: "weights", &values, step: 10);
/// tb_log!(writer, text: "note", "hello");
/// tb_log!(writer, text: "note", "hello", step: 10);
/// ```
///
/// Without `step:` the call passes `None` as the step; with it, `Some(step)`.
///
/// Every arm expands to a plain method-call expression (no trailing semicolon),
/// so the macro works in both statement and expression position.
#[macro_export]
macro_rules! tb_log {
    ($writer:expr, scalar: $tag:expr, $value:expr $(,)?) => {
        $writer.add_scalar($tag, $value, None)
    };
    ($writer:expr, scalar: $tag:expr, $value:expr, step: $step:expr $(,)?) => {
        $writer.add_scalar($tag, $value, Some($step))
    };
    ($writer:expr, histogram: $tag:expr, $values:expr $(,)?) => {
        $writer.add_histogram($tag, $values, None)
    };
    ($writer:expr, histogram: $tag:expr, $values:expr, step: $step:expr $(,)?) => {
        $writer.add_histogram($tag, $values, Some($step))
    };
    ($writer:expr, text: $tag:expr, $text:expr $(,)?) => {
        $writer.add_text($tag, $text, None)
    };
    ($writer:expr, text: $tag:expr, $text:expr, step: $step:expr $(,)?) => {
        $writer.add_text($tag, $text, Some($step))
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Collects the readable entries directly inside `dir`.
    fn entries_in(dir: &Path) -> Vec<fs::DirEntry> {
        fs::read_dir(dir).unwrap().filter_map(Result::ok).collect()
    }

    /// Asserts that `dir` holds exactly one entry and that it is named like an
    /// event file, then returns that entry.
    fn sole_event_file(dir: &Path) -> fs::DirEntry {
        let mut found = entries_in(dir);
        assert_eq!(found.len(), 1);
        let only = found.remove(0);
        assert!(only
            .file_name()
            .to_string_lossy()
            .starts_with("events.out.tfevents"));
        only
    }

    /// Creating a writer produces a single event file in the log directory.
    #[test]
    fn test_summary_writer_creation() {
        let dir = tempdir().unwrap();
        // Keep the writer alive while the directory is inspected.
        let _writer = SummaryWriter::new(dir.path()).unwrap();
        sole_event_file(dir.path());
    }

    /// Flushing logged scalars leaves a non-empty event file behind.
    #[test]
    fn test_scalar_logging() {
        let dir = tempdir().unwrap();
        let mut writer = SummaryWriter::new(dir.path()).unwrap();
        writer.add_scalar("loss", 0.5, Some(0));
        writer.add_scalar("accuracy", 0.95, Some(0));
        writer.flush();

        let event_file = sole_event_file(dir.path());
        let file_contents = fs::read(event_file.path()).unwrap();
        assert!(
            !file_contents.is_empty(),
            "Event file should contain data after flush(), but was empty on platform: {}",
            std::env::consts::OS
        );
    }

    /// Logging and flushing a histogram completes without panicking.
    #[test]
    fn test_histogram_logging() {
        let dir = tempdir().unwrap();
        let mut writer = SummaryWriter::new(dir.path()).unwrap();
        let values = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        writer.add_histogram("weights", &values, Some(0));
        writer.flush();
    }
}