use chrono::prelude::{DateTime, Local};
use std::{
collections::{
hash_map::{IntoIter, Iter, Keys},
HashMap,
},
convert::Into,
iter::IntoIterator,
path::Path,
};
use tensorboard_rs::summary_writer::SummaryWriter;
use crate::error::LrrError;
/// One value stored under a key in a [`Record`].
#[derive(Clone, Debug)]
pub enum RecordValue {
    /// A single `f32`.
    Scalar(f32),
    /// A timestamp in the local time zone.
    DateTime(DateTime<Local>),
    /// A flat `f32` vector.
    Array1(Vec<f32>),
    /// A flat `f32` buffer paired with a two-element shape.
    /// NOTE(review): element ordering (row- vs column-major) is not fixed here — confirm with producers.
    Array2(Vec<f32>, [usize; 2]),
    /// A flat `f32` buffer paired with a three-element shape.
    Array3(Vec<f32>, [usize; 3]),
    /// Free-form text.
    String(String),
}
/// A collection of named [`RecordValue`]s, keyed by `String`.
///
/// Backed by a `HashMap`, so iteration order over keys/entries is unspecified.
#[derive(Debug)]
pub struct Record(HashMap<String, RecordValue>);

impl Record {
    /// Creates a record with no entries.
    pub fn empty() -> Self {
        Self(HashMap::new())
    }

    /// Builds a record from `(key, value)` pairs, cloning each key and value.
    ///
    /// If a key appears more than once, the last pair with that key wins.
    pub fn from_slice<K: Into<String> + Clone>(s: &[(K, RecordValue)]) -> Self {
        Self(
            s.iter()
                .map(|(k, v)| (k.clone().into(), v.clone()))
                .collect(),
        )
    }

    /// Returns an iterator over the keys (unspecified order).
    pub fn keys(&self) -> Keys<'_, String, RecordValue> {
        self.0.keys()
    }

    /// Inserts `v` under `k`, replacing any existing value for that key.
    pub fn insert(&mut self, k: impl Into<String>, v: RecordValue) {
        self.0.insert(k.into(), v);
    }

    /// Returns an iterator over `(key, value)` references (unspecified order).
    pub fn iter(&self) -> Iter<'_, String, RecordValue> {
        self.0.iter()
    }

    /// Consumes the record, yielding owned `(key, value)` pairs (unspecified order).
    pub fn into_iter_in_record(self) -> IntoIter<String, RecordValue> {
        self.0.into_iter()
    }

    /// Returns the value stored under `k`, if any.
    pub fn get(&self, k: &str) -> Option<&RecordValue> {
        self.0.get(k)
    }

    /// Combines two records into one.
    ///
    /// On key collisions, the value from `record` (the argument) replaces the one in `self`.
    pub fn merge(self, record: Record) -> Self {
        // Reuse `self`'s map instead of rebuilding it from scratch; `extend` inserts in
        // order, so entries from `record` overwrite existing ones, as before.
        let mut merged = self.0;
        merged.extend(record.0);
        Record(merged)
    }

    /// Looks up `k`, mapping a missing key to `LrrError::RecordKeyError(k)`.
    fn value_or_key_error(&self, k: &str) -> Result<&RecordValue, LrrError> {
        self.0
            .get(k)
            .ok_or_else(|| LrrError::RecordKeyError(k.to_string()))
    }

    /// Returns the `Scalar` stored under `k`.
    ///
    /// # Errors
    /// `RecordKeyError` if `k` is absent; `RecordValueTypeError("Scalar")` if the value
    /// under `k` is a different variant.
    pub fn get_scalar(&self, k: &str) -> Result<f32, LrrError> {
        match self.value_or_key_error(k)? {
            RecordValue::Scalar(v) => Ok(*v),
            _ => Err(LrrError::RecordValueTypeError("Scalar".to_string())),
        }
    }

    /// Returns a copy of the `Array1` stored under `k`.
    ///
    /// # Errors
    /// `RecordKeyError` if `k` is absent; `RecordValueTypeError("Array1")` on a variant mismatch.
    pub fn get_array1(&self, k: &str) -> Result<Vec<f32>, LrrError> {
        match self.value_or_key_error(k)? {
            RecordValue::Array1(v) => Ok(v.clone()),
            _ => Err(LrrError::RecordValueTypeError("Array1".to_string())),
        }
    }

    /// Returns a copy of the `Array2` data and its shape stored under `k`.
    ///
    /// # Errors
    /// `RecordKeyError` if `k` is absent; `RecordValueTypeError("Array2")` on a variant mismatch.
    pub fn get_array2(&self, k: &str) -> Result<(Vec<f32>, [usize; 2]), LrrError> {
        match self.value_or_key_error(k)? {
            // The shape array is `Copy`; no clone needed.
            RecordValue::Array2(v, s) => Ok((v.clone(), *s)),
            _ => Err(LrrError::RecordValueTypeError("Array2".to_string())),
        }
    }

    /// Returns a copy of the `Array3` data and its shape stored under `k`.
    ///
    /// # Errors
    /// `RecordKeyError` if `k` is absent; `RecordValueTypeError("Array3")` on a variant mismatch.
    pub fn get_array3(&self, k: &str) -> Result<(Vec<f32>, [usize; 3]), LrrError> {
        match self.value_or_key_error(k)? {
            RecordValue::Array3(v, s) => Ok((v.clone(), *s)),
            _ => Err(LrrError::RecordValueTypeError("Array3".to_string())),
        }
    }

    /// Returns a copy of the `String` stored under `k`.
    ///
    /// # Errors
    /// `RecordKeyError` if `k` is absent; `RecordValueTypeError("String")` on a variant mismatch.
    pub fn get_string(&self, k: &str) -> Result<String, LrrError> {
        match self.value_or_key_error(k)? {
            RecordValue::String(s) => Ok(s.clone()),
            _ => Err(LrrError::RecordValueTypeError("String".to_string())),
        }
    }
}
/// A sink that accepts [`Record`]s one at a time.
pub trait Recorder {
    /// Hands `record` to the recorder, which takes ownership of it.
    ///
    /// What is done with the record (stored, written out, discarded) is up to the implementor.
    fn write(&mut self, record: Record);
}
/// A [`Recorder`] that throws away everything it is given.
pub struct NullRecorder {}

impl Recorder for NullRecorder {
    /// Drops the record immediately; nothing is stored or emitted.
    fn write(&mut self, _: Record) {}
}
/// A [`Recorder`] that writes records to TensorBoard through `tensorboard_rs`'s `SummaryWriter`.
///
/// Every record must hold a `RecordValue::Scalar` under the step key (`"opt_steps"` for both
/// constructors). That value is used as the step for all other entries in the record.
/// - `Scalar` entries are written with `add_scalar`.
/// - `Array2` entries are written as 3-channel grayscale images with `add_image`.
/// - `DateTime` entries are always skipped.
/// - Any other variant is skipped or triggers a panic, depending on the constructor used.
pub struct TensorboardRecorder {
    writer: SummaryWriter,
    step_key: String,
    ignore_unsupported_value: bool,
}

impl TensorboardRecorder {
    /// Creates a recorder logging to `logdir` that silently skips unsupported value types.
    pub fn new<P: AsRef<Path>>(logdir: P) -> Self {
        Self::with_unsupported_policy(logdir, true)
    }

    /// Creates a recorder logging to `logdir` that panics on unsupported value types.
    pub fn new_with_check_unsupported_value<P: AsRef<Path>>(logdir: P) -> Self {
        Self::with_unsupported_policy(logdir, false)
    }

    /// Constructor shared by the public constructors.
    fn with_unsupported_policy<P: AsRef<Path>>(logdir: P, ignore_unsupported_value: bool) -> Self {
        Self {
            writer: SummaryWriter::new(logdir),
            step_key: "opt_steps".to_string(),
            ignore_unsupported_value,
        }
    }

    /// Min-max normalizes `data` to `0..=255` and returns three identical copies of it back to
    /// back (channel-major), matching the `[3, rows, cols]` shape passed to `add_image`.
    fn gray_to_three_channel_u8(data: &[f32]) -> Vec<u8> {
        let min = data.iter().copied().fold(f32::MAX, f32::min);
        let max = data.iter().copied().fold(-f32::MAX, f32::max);
        let range = max - min;
        // A constant (or empty) array has no range. Dividing by it would give 0/0 = NaN, which
        // only became 0 through the saturating `as u8` cast. Use 1.0 so every pixel maps to 0.
        let scale = if range > 0.0 { range } else { 1.0 };
        let gray: Vec<u8> = data
            .iter()
            .map(|&e| ((e - min) / scale * 255f32) as u8)
            .collect();
        gray.repeat(3)
    }
}

impl Recorder for TensorboardRecorder {
    /// Writes every entry of `record`, except the step entry, at the record's step.
    ///
    /// # Panics
    /// - If the step key is missing or does not hold a `Scalar`.
    /// - If an unsupported value is encountered and the recorder was built with
    ///   [`TensorboardRecorder::new_with_check_unsupported_value`].
    fn write(&mut self, record: Record) {
        // `as usize` truncates the fractional part and saturates negatives to 0.
        let step = match record.get(&self.step_key) {
            Some(RecordValue::Scalar(v)) => *v as usize,
            Some(other) => panic!(
                "step key {:?} must hold a Scalar, found {:?}",
                self.step_key, other
            ),
            None => panic!("record has no step key {:?}", self.step_key),
        };

        for (k, v) in record.iter() {
            if *k == self.step_key {
                continue;
            }
            match v {
                RecordValue::Scalar(x) => self.writer.add_scalar(k, *x, step),
                RecordValue::DateTime(_) => {}
                RecordValue::Array2(data, [rows, cols]) => {
                    let pixels = Self::gray_to_three_channel_u8(data);
                    self.writer
                        .add_image(k, pixels.as_slice(), &[3, *rows, *cols], step)
                }
                _ => {
                    if !self.ignore_unsupported_value {
                        panic!("Unsupported value: {:?}", (k, v));
                    }
                }
            };
        }
    }
}
/// A [`Recorder`] that keeps every record it receives in memory, in arrival order.
#[derive(Default)]
pub struct BufferedRecorder(Vec<Record>);

impl BufferedRecorder {
    /// Creates a recorder with an empty buffer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Borrows the buffered records, oldest first.
    pub fn iter(&self) -> std::slice::Iter<'_, Record> {
        self.0.iter()
    }
}

impl Recorder for BufferedRecorder {
    /// Appends `record` to the end of the buffer.
    fn write(&mut self, record: Record) {
        self.0.push(record);
    }
}